diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile
index 4adb7ab5c..f64446e7e 100644
--- a/.github/scripts/docker/Dockerfile
+++ b/.github/scripts/docker/Dockerfile
@@ -36,7 +36,9 @@ RUN pip install --no-cache-dir \
\
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \
+ cython \
dill \
+ espnet_tts_frontend \
graphviz \
kaldi-decoder \
kaldi_native_io \
@@ -45,13 +47,15 @@ RUN pip install --no-cache-dir \
kaldilm \
matplotlib \
multi_quantization \
+ numba \
numpy \
onnx \
onnxmltools \
onnxruntime \
+ piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html \
+ pypinyin==0.50.0 \
pytest \
sentencepiece>=0.1.96 \
- pypinyin==0.50.0 \
six \
tensorboard \
typeguard
diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py
index 7bb8ac676..675e37c37 100755
--- a/.github/scripts/docker/generate_build_matrix.py
+++ b/.github/scripts/docker/generate_build_matrix.py
@@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version):
def get_matrix():
k2_version = "1.24.4.dev20240223"
kaldifeat_version = "1.25.4.dev20240223"
- version = "20240223"
+ version = "20240318"
python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"]
torch_version = ["1.13.0", "1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.1", "2.1.2"]
torch_version += ["2.2.0", "2.2.1"]
diff --git a/.github/scripts/librispeech/ASR/run.sh b/.github/scripts/librispeech/ASR/run.sh
index 293ed66e5..b4450afea 100755
--- a/.github/scripts/librispeech/ASR/run.sh
+++ b/.github/scripts/librispeech/ASR/run.sh
@@ -64,6 +64,46 @@ function run_diagnostics() {
--print-diagnostics 1
}
+function test_streaming_zipformer_ctc_hlg() {
+ repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
+
+ log "Downloading pre-trained model from $repo_url"
+ git lfs install
+ git clone $repo_url
+ repo=$(basename $repo_url)
+
+ rm $repo/exp-ctc-rnnt-small/*.onnx
+ ls -lh $repo/exp-ctc-rnnt-small
+
+ # export models to onnx
+ ./zipformer/export-onnx-streaming-ctc.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 3 \
+ --exp-dir $repo/exp-ctc-rnnt-small \
+ --causal 1 \
+ --use-ctc 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ \
+ --num-encoder-layers 2,2,2,2,2,2 \
+ --feedforward-dim 512,768,768,768,768,768 \
+ --encoder-dim 192,256,256,256,256,256 \
+ --encoder-unmasked-dim 192,192,192,192,192,192
+
+ ls -lh $repo/exp-ctc-rnnt-small
+
+ for wav in 0.wav 1.wav 8k.wav; do
+ python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
+ --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
+ --words $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.fst \
+ $repo/test_wavs/$wav
+ done
+
+ rm -rf $repo
+}
+
function test_pruned_transducer_stateless_2022_03_12() {
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
@@ -1577,6 +1617,7 @@ function test_transducer_bpe_500_2021_12_23() {
prepare_data
run_diagnostics
+test_streaming_zipformer_ctc_hlg
test_pruned_transducer_stateless_2022_03_12
test_pruned_transducer_stateless2_2022_04_29
test_pruned_transducer_stateless3_2022_04_29
diff --git a/docs/source/recipes/Finetune/adapter/finetune_adapter.rst b/docs/source/recipes/Finetune/adapter/finetune_adapter.rst
new file mode 100644
index 000000000..a94b008f6
--- /dev/null
+++ b/docs/source/recipes/Finetune/adapter/finetune_adapter.rst
@@ -0,0 +1,225 @@
+Finetune from a pre-trained Zipformer model with adapters
+=========================================================
+
+This tutorial shows you how to fine-tune a pre-trained **Zipformer**
+transducer model on a new dataset with adapters.
+Adapters are compact and efficient modules that can be integrated into a pre-trained model
+to improve the model's performance on a new domain. Adapters are injected
+between different modules of the well-trained neural network. During training, only the parameters
+of the adapters are updated. Adapter-based fine-tuning achieves competitive performance
+while requiring much less GPU memory than full fine-tuning. For more details about adapters,
+please refer to the original `paper `_.
+
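+To make the idea concrete, below is a minimal sketch of a bottleneck adapter in
+PyTorch. It only illustrates the general technique: the class name, layer sizes
+and initialization are assumptions for illustration and do not claim to match the
+exact implementation used in ``zipformer_adapter``.
+
+.. code-block:: python
+
+   import torch
+   import torch.nn as nn
+
+   class Adapter(nn.Module):
+       """A bottleneck adapter: down-project, apply a non-linearity, up-project,
+       and add a residual connection so the pre-trained path is preserved."""
+
+       def __init__(self, embed_dim: int, adapter_dim: int = 8):
+           super().__init__()
+           self.down = nn.Linear(embed_dim, adapter_dim)
+           self.activation = nn.ReLU()
+           self.up = nn.Linear(adapter_dim, embed_dim)
+           # Start as an identity mapping so the adapter does not perturb
+           # the pre-trained model at the beginning of fine-tuning.
+           nn.init.zeros_(self.up.weight)
+           nn.init.zeros_(self.up.bias)
+
+       def forward(self, x: torch.Tensor) -> torch.Tensor:
+           return x + self.up(self.activation(self.down(x)))
+
+The small ``adapter_dim`` bottleneck is what keeps the number of extra parameters low.
+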
+.. HINT::
+
+ We assume you have read the page :ref:`install icefall` and have set up
+ the environment for ``icefall``.
+
+.. HINT::
+
+ We recommend using one or more GPUs to run this recipe.
+
+For illustration purposes, we fine-tune a Zipformer transducer model
+pre-trained on `LibriSpeech`_ on the small subset of `GigaSpeech`_. You can use your
+own data for fine-tuning as long as you create a manifest for your new dataset.
+
+Data preparation
+----------------
+
+Please follow the instructions in the `GigaSpeech recipe `_
+to prepare the fine-tuning data used in this tutorial. Only the small subset of GigaSpeech is required.
+
+
+Model preparation
+-----------------
+
+We are using the Zipformer model trained on the full LibriSpeech dataset (960 hours) as the initialization. The
+checkpoint of the model can be downloaded via the following command:
+
+.. code-block:: bash
+
+ $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+ $ cd icefall-asr-librispeech-zipformer-2023-05-15/exp
+ $ git lfs pull --include "pretrained.pt"
+ $ ln -s pretrained.pt epoch-99.pt
+ $ cd ../data/lang_bpe_500
+ $ git lfs pull --include bpe.model
+ $ cd ../../..
+
+Before fine-tuning, let's test the model's WER on the new domain. The following command performs
+decoding on the GigaSpeech test sets:
+
+.. code-block:: bash
+
+ ./zipformer/decode_gigaspeech.py \
+ --epoch 99 \
+ --avg 1 \
+ --exp-dir icefall-asr-librispeech-zipformer-2023-05-15/exp \
+ --use-averaged-model 0 \
+ --max-duration 1000 \
+ --decoding-method greedy_search
+
+You should see the following numbers:
+
+.. code-block::
+
+ For dev, WER of different settings are:
+ greedy_search 20.06 best for dev
+
+ For test, WER of different settings are:
+ greedy_search 19.27 best for test
+
+
+Fine-tune with adapter
+----------------------
+
+We insert 4 adapters with residual connections into each ``Zipformer2EncoderLayer``.
+The original model parameters remain untouched during training; only the parameters of
+the adapters are updated. The following command starts a fine-tuning experiment with adapters:
+
+.. code-block:: bash
+
+ $ do_finetune=1
+ $ use_adapters=1
+ $ adapter_dim=8
+
+ $ ./zipformer_adapter/train.py \
+ --world-size 2 \
+ --num-epochs 20 \
+ --start-epoch 1 \
+ --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
+ --use-fp16 1 \
+ --base-lr 0.045 \
+ --use-adapters $use_adapters --adapter-dim $adapter_dim \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --do-finetune $do_finetune \
+ --master-port 13022 \
+ --finetune-ckpt icefall-asr-librispeech-zipformer-2023-05-15/exp/pretrained.pt \
+ --max-duration 1000
+
+The following arguments are related to fine-tuning:
+
+- ``--do-finetune``
+ If True, do fine-tuning by initializing the model from a pre-trained checkpoint.
+ **Note that if you want to resume your fine-tuning experiment from a certain epoch, you
+ need to set this to False.**
+
+- ``--use-adapters``
+ Whether adapters are used during fine-tuning.
+
+- ``--adapter-dim``
+ The bottleneck dimension of the adapter module. Typically a small number.
+
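+As a rough sketch of how "only the parameters in the adapters are updated" can be
+realized, the snippet below freezes every parameter whose name does not contain
+``adapter`` and exposes the remaining ones for the optimizer. The name-matching
+convention is an illustrative assumption, not the exact code used by
+``zipformer_adapter/train.py``.
+
+.. code-block:: python
+
+   import torch
+
+   def freeze_non_adapter_parameters(model: torch.nn.Module) -> None:
+       """Freeze the pre-trained weights; keep only adapter weights trainable."""
+       for name, param in model.named_parameters():
+           param.requires_grad = "adapter" in name
+
+   def trainable_parameters(model: torch.nn.Module):
+       """Yield the parameters that the optimizer should update."""
+       return (p for p in model.parameters() if p.requires_grad)
+
+After freezing, only ``trainable_parameters(model)`` would be handed to the optimizer;
+the actual recipe uses its own optimizer and learning-rate schedule.
+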
+In the training log, you should see the total number of trainable parameters:
+
+.. code-block::
+
+ 2024-02-22 21:22:03,808 INFO [train.py:1277] A total of 761344 trainable parameters (1.148% of the whole model)
+
+The trainable parameters make up only 1.15% of the whole model, so training is much faster
+and requires less memory than full fine-tuning.
+
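+If you want to verify this number yourself, a quick check (assuming ``model`` is the
+constructed fine-tuning model after the non-adapter parameters have been frozen) looks like:
+
+.. code-block:: python
+
+   num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+   num_total = sum(p.numel() for p in model.parameters())
+   print(
+       f"A total of {num_trainable} trainable parameters "
+       f"({100.0 * num_trainable / num_total:.3f}% of the whole model)"
+   )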
+
+Decoding
+--------
+
+After training, let's test the WERs. To decode the GigaSpeech test sets,
+you can execute the following command:
+
+.. code-block:: bash
+
+ $ epoch=20
+ $ avg=10
+ $ use_adapters=1
+ $ adapter_dim=8
+
+ $ ./zipformer_adapter/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --use-averaged-model 1 \
+ --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
+ --max-duration 600 \
+ --use-adapters $use_adapters \
+ --adapter-dim $adapter_dim \
+ --decoding-method greedy_search
+
+You should see the following numbers:
+
+.. code-block::
+
+ For dev, WER of different settings are:
+ greedy_search 15.44 best for dev
+
+ For test, WER of different settings are:
+ greedy_search 15.42 best for test
+
+
+The WER on the test set improves from 19.27 to 15.42, demonstrating the effectiveness of adapters.
+
+The same model can also be used to decode the LibriSpeech test sets. You can deactivate the adapters
+to recover the performance of the original model:
+
+.. code-block:: bash
+
+ $ epoch=20
+ $ avg=1
+ $ use_adapters=0
+ $ adapter_dim=8
+
+ $ ./zipformer_adapter/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --use-averaged-model 1 \
+ --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
+ --max-duration 600 \
+ --use-adapters $use_adapters \
+ --adapter-dim $adapter_dim \
+ --decoding-method greedy_search
+
+
+.. code-block::
+
+ For dev, WER of different settings are:
+ greedy_search 2.23 best for test-clean
+
+ For test, WER of different settings are:
+ greedy_search 4.96 best for test-other
+
+The numbers are the same as reported in `icefall `_, so adapter-based
+fine-tuning is also very flexible: the same model can be used for decoding on both the original and the target domain.
+
+
+Export the model
+----------------
+
+After training, the model can easily be exported to ``onnx`` format using the following command:
+
+.. code-block:: bash
+
+ $ use_adapters=1
+ $ adapter_dim=8
+
+ $ ./zipformer_adapter/export-onnx.py \
+ --tokens icefall-asr-librispeech-zipformer-2023-05-15/data/lang_bpe_500/tokens.txt \
+ --use-averaged-model 1 \
+ --epoch 20 \
+ --avg 10 \
+ --exp-dir zipformer_adapter/exp_giga_finetune_adapters${use_adapters}_adapter_dim${adapter_dim} \
+ --use-adapters $use_adapters \
+ --adapter-dim $adapter_dim \
+ --num-encoder-layers "2,2,3,4,3,2" \
+ --downsampling-factor "1,2,4,8,4,2" \
+ --feedforward-dim "512,768,1024,1536,1024,768" \
+ --num-heads "4,4,4,8,4,4" \
+ --encoder-dim "192,256,384,512,384,256" \
+ --query-head-dim 32 \
+ --value-head-dim 12 \
+ --pos-head-dim 4 \
+ --pos-dim 48 \
+ --encoder-unmasked-dim "192,192,256,256,256,192" \
+ --cnn-module-kernel "31,31,15,15,15,31" \
+ --decoder-dim 512 \
+ --joiner-dim 512 \
+ --causal False \
+ --chunk-size "16,32,64,-1" \
+ --left-context-frames "64,128,256,-1"
\ No newline at end of file
diff --git a/docs/source/recipes/Finetune/index.rst b/docs/source/recipes/Finetune/index.rst
index e62b8980f..7f36d2687 100644
--- a/docs/source/recipes/Finetune/index.rst
+++ b/docs/source/recipes/Finetune/index.rst
@@ -13,3 +13,4 @@ data to improve the performance on new domains.
:caption: Table of Contents
from_supervised/finetune_zipformer
+ adapter/finetune_adapter
diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index 74a7b5933..2cb476e20 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -419,7 +419,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@@ -432,7 +432,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=enable_log,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/conformer_mmi/decode.py b/egs/aishell/ASR/conformer_mmi/decode.py
index 20a855e7f..8a2daa93e 100755
--- a/egs/aishell/ASR/conformer_mmi/decode.py
+++ b/egs/aishell/ASR/conformer_mmi/decode.py
@@ -431,7 +431,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@@ -444,7 +444,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=enable_log,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
index fb6c7c481..f41ea6776 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py
@@ -390,7 +390,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -402,7 +402,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
index 27c64efaa..3901a330c 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py
@@ -526,7 +526,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -538,7 +538,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
index 696eea906..d50bccf82 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
@@ -444,7 +444,7 @@ def save_results(
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
- store_transcripts(filename=recog_path, texts=results_char)
+ store_transcripts(filename=recog_path, texts=results_char, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -452,7 +452,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
index 6027273b2..058d0ff6b 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
@@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py
index 9d9dd4288..2dc835f3b 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py
@@ -85,6 +85,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -878,9 +879,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
index da9000164..46f542641 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/decode.py
@@ -581,7 +581,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -594,7 +594,11 @@ def save_results(
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
index 3858bafd7..811269989 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -871,9 +872,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
index 0e783e92b..61b929091 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -492,7 +492,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -500,7 +500,11 @@ def save_results(
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 0fba3b58f..6653d9d9c 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -882,9 +883,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
index 2e1044658..f3b0f1e11 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -881,9 +882,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
index 824ca2a92..05e52f560 100755
--- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py
@@ -278,7 +278,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -289,7 +289,13 @@ def save_results(
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
- wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
+ wer = write_error_stats(
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
+ )
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index d23f4f883..d958a6338 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -327,7 +327,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
@@ -338,7 +338,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
index d164b6890..57f7a8239 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py
@@ -372,7 +372,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -384,7 +384,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py
index 0a7d87fe8..56f3724eb 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/decode.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py
@@ -376,7 +376,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -388,7 +388,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py
index 7f841dcb7..c632d0757 100755
--- a/egs/aishell/ASR/whisper/decode.py
+++ b/egs/aishell/ASR/whisper/decode.py
@@ -358,7 +358,7 @@ def save_results(
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@@ -373,7 +373,11 @@ def save_results(
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
+ f,
+ f"{test_set_name}-{key}",
+ results_char,
+ enable_log=enable_log,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index 073b23713..6ccb8d363 100755
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -793,7 +793,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
+ 512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/aishell/ASR/zipformer/decode.py b/egs/aishell/ASR/zipformer/decode.py
index 1968904ae..538189e52 100755
--- a/egs/aishell/ASR/zipformer/decode.py
+++ b/egs/aishell/ASR/zipformer/decode.py
@@ -560,7 +560,7 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
- store_transcripts(filename=recog_path, texts=results)
+ store_transcripts(filename=recog_path, texts=results, char_level=True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -570,7 +570,11 @@ def save_results(
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
- f, f"{test_set_name}-{key}", results, enable_log=True
+ f,
+ f"{test_set_name}-{key}",
+ results,
+ enable_log=True,
+ compute_CER=True,
)
test_set_wers[key] = wer
diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py
index d381649e4..a25979226 100755
--- a/egs/aishell/ASR/zipformer/train.py
+++ b/egs/aishell/ASR/zipformer/train.py
@@ -86,6 +86,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -985,9 +986,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py
index a2bf96b29..0713c5787 100755
--- a/egs/aishell/ASR/zipformer/train_bbpe.py
+++ b/egs/aishell/ASR/zipformer/train_bbpe.py
@@ -83,6 +83,7 @@ from icefall.checkpoint import (
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -570,9 +571,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
index 8f09f1aa5..30879d8d2 100755
--- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
+++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py
@@ -70,6 +70,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -851,9 +852,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py
index 9b67141c0..d62cdadb7 100755
--- a/egs/ami/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py
@@ -69,6 +69,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -842,9 +843,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py
index cd5fafc34..adc6a8495 100755
--- a/egs/ami/SURT/dprnn_zipformer/train.py
+++ b/egs/ami/SURT/dprnn_zipformer/train.py
@@ -75,6 +75,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1138,9 +1139,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py
index 9f3b4425f..ac5b0dadc 100755
--- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py
+++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py
@@ -75,6 +75,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1129,9 +1130,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
index 4aedeffe4..4957c0c31 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py
@@ -79,6 +79,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -871,9 +872,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 0426bc9a3..a3f387636 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -889,9 +889,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise RuntimeError(f", exiting: {cur_grad_scale}")
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@@ -1037,7 +1035,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
+ 512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
index 3a10c5d81..81c69e5e0 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -965,9 +966,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@@ -1120,7 +1119,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
+ 512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
index a9bc9c2a2..728104580 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -888,9 +889,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
@@ -1036,7 +1035,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
+ 512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 685f6ece6..6d256308c 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -909,9 +910,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
index 73fcd67aa..ef7ea9013 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -908,9 +909,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py
index c5335562c..f0ad98147 100755
--- a/egs/gigaspeech/ASR/zipformer/train.py
+++ b/egs/gigaspeech/ASR/zipformer/train.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1031,9 +1032,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py
index 2cd7c868b..a7ba56127 100755
--- a/egs/gigaspeech/KWS/zipformer/finetune.py
+++ b/egs/gigaspeech/KWS/zipformer/finetune.py
@@ -100,6 +100,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -371,9 +372,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py
index e7387dd39..a4d670169 100755
--- a/egs/gigaspeech/KWS/zipformer/train.py
+++ b/egs/gigaspeech/KWS/zipformer/train.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1034,9 +1035,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py
index 6598f8b5d..90d742e7c 100755
--- a/egs/libricss/SURT/dprnn_zipformer/train.py
+++ b/egs/libricss/SURT/dprnn_zipformer/train.py
@@ -85,6 +85,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1169,9 +1170,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py
index 1c1b0c28c..8c37430ec 100755
--- a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py
+++ b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1056,9 +1057,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py
index c97da4a11..8d4d9d067 100644
--- a/egs/libriheavy/ASR/zipformer/train.py
+++ b/egs/libriheavy/ASR/zipformer/train.py
@@ -93,6 +93,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1036,9 +1037,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py
index c8b20d021..93f7e1248 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py
@@ -103,6 +103,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1051,9 +1052,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
index 9822b99c1..2a2c206aa 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
@@ -117,6 +117,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -855,9 +856,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
- context_dim=4 * 768
- if params.context_injection
- else -1, # the output dim of text encoder
+ context_dim=(
+ 4 * 768 if params.context_injection else -1
+ ), # the output dim of text encoder
context_injection=params.context_injection,
)
return joiner
@@ -1398,9 +1399,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index 5c5a76917..080f81c91 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -36,6 +36,7 @@ The following table lists the differences among them.
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
+| `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune Zipformer with LoRA |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
index a7a8ef149..e7546ec45 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
@@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -976,9 +977,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index fac3706d2..436ec53b4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -878,9 +879,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
index d8fa08372..b35e56abc 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -902,9 +903,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
index 25a1aa674..c2d877a93 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py
@@ -77,6 +77,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -891,9 +892,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 9a6d2155b..8e239e322 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -880,9 +881,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
index e1bdce49d..8bd00bbef 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -80,6 +80,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -879,9 +880,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
index 1642ef4b7..da5e144c9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
@@ -84,6 +84,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -946,9 +947,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
index 3f271c5b4..646f30ca1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -946,9 +947,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 2ab051e83..814390ad6 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -479,18 +479,14 @@ class LibriSpeechAsrDataModule:
@lru_cache()
def gigaspeech_subset_small_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech subset-S cuts")
- return load_manifest_lazy(self.args.manifest_dir / "gigaspeech_cuts_S.jsonl.gz")
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_S.jsonl.gz")
@lru_cache()
def gigaspeech_dev_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech dev cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
- )
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
@lru_cache()
def gigaspeech_test_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech test cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
- )
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py
index 8920764cd..1bfd071de 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py
@@ -66,6 +66,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import UniqLexicon
from icefall.utils import (
@@ -883,9 +884,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
index 3c0f74005..1eba6093b 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming-ctc.py
@@ -32,7 +32,7 @@ This script exports a CTC model from PyTorch to ONNX.
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
- --left-context-frames 64 \
+ --left-context-frames 128 \
--use-ctc 1
The --chunk-size in training is "16,32,64,-1", so we select one of them
@@ -41,7 +41,7 @@ whose value is "64,128,256,-1".
It will generate the following file inside $repo/exp:
- - ctc-epoch-99-avg-1-chunk-16-left-64.onnx
+ - ctc-epoch-99-avg-1-chunk-16-left-128.onnx
See ./onnx_pretrained-streaming-ctc.py for how to use the exported ONNX models.
"""
diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index 6bc9b1858..5d0c9ea43 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -48,7 +48,7 @@ popd
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
- --left-context-frames 64
+ --left-context-frames 128
The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to `--left-context`,
@@ -56,9 +56,9 @@ whose value is "64,128,256,-1".
It will generate the following 3 files inside $repo/exp:
- - encoder-epoch-99-avg-1-chunk-16-left-64.onnx
- - decoder-epoch-99-avg-1-chunk-16-left-64.onnx
- - joiner-epoch-99-avg-1-chunk-16-left-64.onnx
+ - encoder-epoch-99-avg-1-chunk-16-left-128.onnx
+ - decoder-epoch-99-avg-1-chunk-16-left-128.onnx
+ - joiner-epoch-99-avg-1-chunk-16-left-128.onnx
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""
@@ -333,6 +333,7 @@ def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
+ feature_dim: int = 80,
) -> None:
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
@@ -343,7 +344,7 @@ def export_encoder_model_onnx(
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length
- x = torch.rand(1, T, 80, dtype=torch.float32)
+ x = torch.rand(1, T, feature_dim, dtype=torch.float32)
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
@@ -724,6 +725,7 @@ def main():
encoder,
encoder_filename,
opset_version=opset_version,
+ feature_dim=params.feature_dim,
)
logging.info(f"Exported encoder to {encoder_filename}")
diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py
index 843d103cc..2f7ec0c17 100755
--- a/egs/librispeech/ASR/zipformer/finetune.py
+++ b/egs/librispeech/ASR/zipformer/finetune.py
@@ -92,6 +92,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1122,9 +1123,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py
new file mode 100755
index 000000000..a8b08de34
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py
@@ -0,0 +1,439 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
+
+"""
+This script loads ONNX models exported by ./export-onnx-streaming-ctc.py
+and uses them to decode waves.
+
+We use the pre-trained model from
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "exp-ctc-rnnt-small/*.pt"
+git lfs pull --include "data/lang_bpe_500/words.txt"
+git lfs pull --include "data/lang_bpe_500/HLG.fst"
+popd
+
+2. Export the model to ONNX
+
+./zipformer/export-onnx-streaming-ctc.py \
+ --tokens $repo/data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 3 \
+ --exp-dir $repo/exp-ctc-rnnt-small \
+ --causal 1 \
+ --use-ctc 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ \
+ --num-encoder-layers 2,2,2,2,2,2 \
+ --feedforward-dim 512,768,768,768,768,768 \
+ --encoder-dim 192,256,256,256,256,256 \
+ --encoder-unmasked-dim 192,192,192,192,192,192
+
+It will generate the following 2 files inside $repo/exp-ctc-rnnt-small:
+
+ - ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx
+ - ctc-epoch-30-avg-3-chunk-16-left-128.onnx
+
+You can use either the ``int8.onnx`` model or just the ``.onnx`` model.
+
+3. Run this file with the exported ONNX models
+
+python3 ./zipformer/onnx_pretrained_ctc_HLG_streaming.py \
+ --nn-model $repo/exp-ctc-rnnt-small/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
+ --words $repo/data/lang_bpe_500/words.txt \
+ --HLG $repo/data/lang_bpe_500/HLG.fst \
+ $repo/test_wavs/0.wav
+
+Note: Even though this script only supports decoding a single file,
+the exported ONNX models do support batch processing.
+
+Note: HLG.fst is generated directly by ../local/prepare_lang_fst.py
+"""
+
+import argparse
+import logging
+from typing import Dict, List, Tuple
+
+import k2
+import kaldifst
+import numpy as np
+import onnxruntime as ort
+import torch
+import torchaudio
+from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions
+from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--nn-model",
+ type=str,
+ required=True,
+ help="Path to the onnx model. ",
+ )
+
+ parser.add_argument(
+ "--words",
+ type=str,
+ required=True,
+ help="""Path to words.txt.""",
+ )
+
+ parser.add_argument(
+ "--HLG",
+ type=str,
+ required=True,
+ help="""Path to HLG.fst.""",
+ )
+
+ parser.add_argument(
+ "sound_file",
+ type=str,
+ help="The input sound file to transcribe. "
+ "Supported formats are those supported by torchaudio.load(). "
+ "For example, wav and flac are supported. ",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(
+ self,
+ model_filename: str,
+ ):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 1
+
+ self.session_opts = session_opts
+
+ self.init_model(model_filename)
+
+ def init_model(self, model_filename: str):
+ self.model = ort.InferenceSession(
+ model_filename,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ self.init_states()
+
+ def init_states(self, batch_size: int = 1):
+ meta = self.model.get_modelmeta().custom_metadata_map
+ logging.info(f"meta={meta}")
+
+ model_type = meta["model_type"]
+ assert model_type == "zipformer2", model_type
+
+ decode_chunk_len = int(meta["decode_chunk_len"])
+ T = int(meta["T"])
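+        # T is the number of feature frames fed to the model per chunk, while
+        # decode_chunk_len is how many frames the streaming loop advances after
+        # each chunk (stored below as self.segment and self.offset).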
+
+ num_encoder_layers = meta["num_encoder_layers"]
+ encoder_dims = meta["encoder_dims"]
+ cnn_module_kernels = meta["cnn_module_kernels"]
+ left_context_len = meta["left_context_len"]
+ query_head_dims = meta["query_head_dims"]
+ value_head_dims = meta["value_head_dims"]
+ num_heads = meta["num_heads"]
+
+ def to_int_list(s):
+ return list(map(int, s.split(",")))
+
+ num_encoder_layers = to_int_list(num_encoder_layers)
+ encoder_dims = to_int_list(encoder_dims)
+ cnn_module_kernels = to_int_list(cnn_module_kernels)
+ left_context_len = to_int_list(left_context_len)
+ query_head_dims = to_int_list(query_head_dims)
+ value_head_dims = to_int_list(value_head_dims)
+ num_heads = to_int_list(num_heads)
+
+ logging.info(f"decode_chunk_len: {decode_chunk_len}")
+ logging.info(f"T: {T}")
+ logging.info(f"num_encoder_layers: {num_encoder_layers}")
+ logging.info(f"encoder_dims: {encoder_dims}")
+ logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
+ logging.info(f"left_context_len: {left_context_len}")
+ logging.info(f"query_head_dims: {query_head_dims}")
+ logging.info(f"value_head_dims: {value_head_dims}")
+ logging.info(f"num_heads: {num_heads}")
+
+ num_encoders = len(num_encoder_layers)
+
+ self.states = []
+ for i in range(num_encoders):
+ num_layers = num_encoder_layers[i]
+ key_dim = query_head_dims[i] * num_heads[i]
+ embed_dim = encoder_dims[i]
+ nonlin_attn_head_dim = 3 * embed_dim // 4
+ value_dim = value_head_dims[i] * num_heads[i]
+ conv_left_pad = cnn_module_kernels[i] // 2
+
+ for layer in range(num_layers):
+ cached_key = torch.zeros(
+ left_context_len[i], batch_size, key_dim
+ ).numpy()
+ cached_nonlin_attn = torch.zeros(
+ 1, batch_size, left_context_len[i], nonlin_attn_head_dim
+ ).numpy()
+ cached_val1 = torch.zeros(
+ left_context_len[i], batch_size, value_dim
+ ).numpy()
+ cached_val2 = torch.zeros(
+ left_context_len[i], batch_size, value_dim
+ ).numpy()
+ cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+ cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
+ self.states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+ embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
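+        # embed_states caches the embedding/subsampling front-end; its shape is
+        # (batch_size, channels, left_pad, freq).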
+ self.states.append(embed_states)
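+        # processed_lens tracks how many feature frames have been consumed so far
+        # for each utterance in the batch.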
+ processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
+ self.states.append(processed_lens)
+
+ self.num_encoders = num_encoders
+
+ self.segment = T
+ self.offset = decode_chunk_len
+
+ def _build_model_input_output(
+ self,
+ x: torch.Tensor,
+ ) -> Tuple[Dict[str, np.ndarray], List[str]]:
+ model_input = {"x": x.numpy()}
+ model_output = ["log_probs"]
+
+ def build_inputs_outputs(tensors, i):
+ assert len(tensors) == 6, len(tensors)
+
+ # (downsample_left, batch_size, key_dim)
+ name = f"cached_key_{i}"
+ model_input[name] = tensors[0]
+ model_output.append(f"new_{name}")
+
+ # (1, batch_size, downsample_left, nonlin_attn_head_dim)
+ name = f"cached_nonlin_attn_{i}"
+ model_input[name] = tensors[1]
+ model_output.append(f"new_{name}")
+
+ # (downsample_left, batch_size, value_dim)
+ name = f"cached_val1_{i}"
+ model_input[name] = tensors[2]
+ model_output.append(f"new_{name}")
+
+ # (downsample_left, batch_size, value_dim)
+ name = f"cached_val2_{i}"
+ model_input[name] = tensors[3]
+ model_output.append(f"new_{name}")
+
+ # (batch_size, embed_dim, conv_left_pad)
+ name = f"cached_conv1_{i}"
+ model_input[name] = tensors[4]
+ model_output.append(f"new_{name}")
+
+ # (batch_size, embed_dim, conv_left_pad)
+ name = f"cached_conv2_{i}"
+ model_input[name] = tensors[5]
+ model_output.append(f"new_{name}")
+
+ for i in range(len(self.states[:-2]) // 6):
+ build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
+
+ # (batch_size, channels, left_pad, freq)
+ name = "embed_states"
+ embed_states = self.states[-2]
+ model_input[name] = embed_states
+ model_output.append(f"new_{name}")
+
+ # (batch_size,)
+ name = "processed_lens"
+ processed_lens = self.states[-1]
+ model_input[name] = processed_lens
+ model_output.append(f"new_{name}")
+
+ return model_input, model_output
+
+ def _update_states(self, states: List[np.ndarray]):
+ self.states = states
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C)
+ Returns:
+          Return a 3-D tensor of log_probs with shape (N, T', vocab_size),
+          where T' is usually equal to ((T-7)//2 - 3)//2.
+ """
+ model_input, model_output_names = self._build_model_input_output(x)
+
+ out = self.model.run(model_output_names, model_input)
+
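+        # out[0] holds log_probs; the remaining outputs are the updated
+        # streaming states, which we keep for the next chunk.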
+ self._update_states(out[1:])
+
+ return torch.from_numpy(out[0])
+
+
+def read_sound_files(
+ filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+ """Read a list of sound files into a list 1-D float32 torch tensors.
+ Args:
+ filenames:
+ A list of sound filenames.
+ expected_sample_rate:
+ The expected sample rate of the sound files.
+ Returns:
+ Return a list of 1-D float32 torch tensors.
+ """
+ ans = []
+ for f in filenames:
+ wave, sample_rate = torchaudio.load(f)
+ if sample_rate != expected_sample_rate:
+ logging.info(f"Resample {sample_rate} to {expected_sample_rate}")
+ wave = torchaudio.functional.resample(
+ wave,
+ orig_freq=sample_rate,
+ new_freq=expected_sample_rate,
+ )
+ # We use only the first channel
+ ans.append(wave[0].contiguous())
+ return ans
+
+
+def create_streaming_feature_extractor() -> OnlineFeature:
+ """Create a CPU streaming feature extractor.
+
+ At present, we assume it returns a fbank feature extractor with
+ fixed options. In the future, we will support passing in the options
+ from outside.
+
+ Returns:
+ Return a CPU streaming feature extractor.
+ """
+ opts = FbankOptions()
+ opts.device = "cpu"
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+ opts.mel_opts.high_freq = -400
+ return OnlineFbank(opts)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ logging.info(vars(args))
+
+ word_table = k2.SymbolTable.from_file(args.words)
+ model = OnnxModel(model_filename=args.nn_model)
+
+ sample_rate = 16000
+
+ logging.info("Constructing Fbank computer")
+ online_fbank = create_streaming_feature_extractor()
+
+ logging.info(f"Reading sound files: {args.sound_file}")
+ waves = read_sound_files(
+ filenames=[args.sound_file],
+ expected_sample_rate=sample_rate,
+ )[0]
+
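+    # Append ~0.3 seconds of silence so the tail of the utterance is flushed
+    # through the streaming model before decoding finishes.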
+ tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
+ wave_samples = torch.cat([waves, tail_padding])
+
+ num_processed_frames = 0
+ segment = model.segment
+ offset = model.offset
+
+ logging.info(f"Loading HLG from {args.HLG}")
+ HLG = kaldifst.StdVectorFst.read(args.HLG)
+
+ decoder_opts = FasterDecoderOptions(max_active=3000)
+ decoder = FasterDecoder(HLG, decoder_opts)
+ decoder.init_decoding()
+
+ chunk = int(1 * sample_rate) # 1 second
+ start = 0
+
+ n = 0
+ while start < wave_samples.numel():
+ end = min(start + chunk, wave_samples.numel())
+
+ # simulate streaming
+ samples = wave_samples[start:end]
+ start += chunk
+
+ online_fbank.accept_waveform(
+ sampling_rate=sample_rate,
+ waveform=samples,
+ )
+
+ while online_fbank.num_frames_ready - num_processed_frames >= segment:
+ frames = []
+ for i in range(segment):
+ frames.append(online_fbank.get_frame(num_processed_frames + i))
+
+ frames = torch.cat(frames, dim=0)
+ frames = frames.unsqueeze(0)
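+            # frames now has shape (1, num_frames, feature_dim), i.e. a batch of size 1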
+
+ log_probs = model(frames)
+ log_probs = log_probs.squeeze(0).cpu().numpy()
+
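+            # n is the absolute index of the first frame in this chunk of
+            # log_probs, so the decoder sees a consistent frame numbering
+            # across chunks.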
+ decodable = DecodableCtc(log_probs, offset=n)
+ n += log_probs.shape[0]
+
+ num_processed_frames += offset
+ decoder.advance_decoding(decodable)
+
+ if not decoder.reached_final():
+ logging.info(f"Failed to decode {args.sound_file}")
+ return
+
+ ok, best_path = decoder.get_best_path()
+
+ (
+ ok,
+ isymbols_out,
+ osymbols_out,
+ total_weight,
+ ) = kaldifst.get_linear_symbol_sequence(best_path)
+
+ if not ok:
+ logging.info(f"Failed to get linear symbol sequence for {args.sound_file}")
+ return
+
+ hyps = " ".join([word_table[i] for i in osymbols_out]).lower()
+ logging.info(f"\n{args.sound_file}\n{hyps}")
+
+ logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index 3ccf7d2f1..1111d32ab 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -90,6 +90,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1021,9 +1022,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py
index e64c10e7a..6c55896a8 100755
--- a/egs/librispeech/ASR/zipformer_adapter/train.py
+++ b/egs/librispeech/ASR/zipformer_adapter/train.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1125,9 +1126,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py
index 60990456d..60112a84e 100755
--- a/egs/librispeech/ASR/zipformer_ctc/train.py
+++ b/egs/librispeech/ASR/zipformer_ctc/train.py
@@ -62,6 +62,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -797,9 +798,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/librispeech/ASR/zipformer_lora/asr_datamodule.py b/egs/librispeech/ASR/zipformer_lora/asr_datamodule.py
new file mode 120000
index 000000000..fa1b8cca3
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/asr_datamodule.py
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/beam_search.py b/egs/librispeech/ASR/zipformer_lora/beam_search.py
new file mode 120000
index 000000000..8554e44cc
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
new file mode 100755
index 000000000..4d93a905f
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
@@ -0,0 +1,1115 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+ modified_beam_search_lm_rescore,
+ modified_beam_search_lm_rescore_LODR,
+ modified_beam_search_lm_shallow_fusion,
+ modified_beam_search_LODR,
+)
+from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params
+
+from icefall import ContextGraph, LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+conversational_filler = [
+ "UH",
+ "UHH",
+ "UM",
+ "EH",
+ "MM",
+ "HM",
+ "AH",
+ "HUH",
+ "HA",
+ "ER",
+ "OOF",
+ "HEE",
+ "ACH",
+ "EEE",
+ "EW",
+]
+unk_tags = ["<UNK>", "<unk>"]
+gigaspeech_punctuations = [
+    "<COMMA>",
+    "<PERIOD>",
+    "<QUESTIONMARK>",
+    "<EXCLAMATIONPOINT>",
+]
+gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
+non_scoring_words = (
+ conversational_filler
+ + unk_tags
+ + gigaspeech_punctuations
+ + gigaspeech_garbage_utterance_tags
+)
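+# The tags and fillers above are stripped from both the reference and the
+# hypothesis before WER is computed (see post_processing below).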
+
+
+def asr_text_post_processing(text: str) -> str:
+ # 1. convert to uppercase
+ text = text.upper()
+
+ # 2. remove hyphen
+ # "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
+ text = text.replace("-", " ")
+
+ # 3. remove non-scoring words from evaluation
+ remaining_words = []
+ for word in text.split():
+ if word in non_scoring_words:
+ continue
+ remaining_words.append(word)
+
+ return " ".join(remaining_words)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - modified_beam_search_LODR
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding-method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding-method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-shallow-fusion",
+ type=str2bool,
+ default=False,
+ help="""Use neural network LM for shallow fusion.
+ If you want to use LODR, you will also need to set this to true
+ """,
+ )
+
+ parser.add_argument(
+ "--lm-type",
+ type=str,
+ default="rnn",
+ help="Type of NN lm",
+ choices=["rnn", "transformer"],
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.3,
+ help="""The scale of the neural network LM
+ Used only when `--use-shallow-fusion` is set to True.
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens-ngram",
+ type=int,
+ default=2,
+ help="""The order of the ngram lm.
+ """,
+ )
+
+ parser.add_argument(
+ "--backoff-id",
+ type=int,
+ default=500,
+ help="ID of the backoff symbol in the ngram LM",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+    Used only when --decoding-method is modified_beam_search or
+ modified_beam_search_LODR.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+    The path of the context biasing lists, one word/phrase per line.
+    Used only when --decoding-method is modified_beam_search or
+ modified_beam_search_LODR.
+ """,
+ )
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def post_processing(
+ results: List[Tuple[str, List[str], List[str]]],
+) -> List[Tuple[str, List[str], List[str]]]:
+ new_results = []
+ for key, ref, hyp in results:
+ new_ref = asr_text_post_processing(" ".join(ref)).split()
+ new_hyp = asr_text_post_processing(" ".join(hyp)).split()
+ new_results.append((key, new_ref, new_hyp))
+ return new_results
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+ LM:
+ A neural network language model.
+ ngram_lm:
+        An n-gram language model.
+      ngram_lm_scale:
+        The scale for the n-gram language model.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_LG":
+ hyp_tokens = fast_beam_search_nbest_LG(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in hyp_tokens:
+ hyps.append([word_table[i] for i in hyp])
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
+ hyp_tokens = modified_beam_search_lm_shallow_fusion(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_LODR":
+ hyp_tokens = modified_beam_search_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LODR_lm=ngram_lm,
+ LODR_lm_scale=ngram_lm_scale,
+ LM=LM,
+ context_graph=context_graph,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search_lm_rescore":
+ lm_scale_list = [0.01 * i for i in range(10, 50)]
+ ans_dict = modified_beam_search_lm_rescore(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ lm_scale_list=lm_scale_list,
+ )
+ elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ lm_scale_list = [0.02 * i for i in range(2, 30)]
+ ans_dict = modified_beam_search_lm_rescore_LODR(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ LM=LM,
+ LODR_lm=ngram_lm,
+ sp=sp,
+ lm_scale_list=lm_scale_list,
+ )
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+ if "LG" in params.decoding_method:
+ key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+ return {key: hyps}
+ elif "modified_beam_search" in params.decoding_method:
+ prefix = f"beam_size_{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ ):
+ ans = dict()
+ assert ans_dict is not None
+ for key, hyps in ans_dict.items():
+ hyps = [sp.decode(hyp).split() for hyp in hyps]
+ ans[f"{prefix}_{key}"] = hyps
+ return ans
+ else:
+ if params.has_contexts:
+ prefix += f"-context-score-{params.context_score}"
+ return {prefix: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ word_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+ LM: Optional[LmScorer] = None,
+ ngram_lm=None,
+ ngram_lm_scale: float = 0.0,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ word_table:
+ The word symbol table.
+ decoding_graph:
+      The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+      only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+      fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut id, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ word_table=word_table,
+ batch=batch,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = post_processing(results)
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ LmScorer.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "fast_beam_search_nbest",
+ "fast_beam_search_nbest_LG",
+ "fast_beam_search_nbest_oracle",
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ )
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if params.causal:
+ assert (
+ "," not in params.chunk_size
+ ), "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ if "nbest" in params.decoding_method:
+ params.suffix += f"-nbest-scale-{params.nbest_scale}"
+ params.suffix += f"-num-paths-{params.num_paths}"
+ if "LG" in params.decoding_method:
+ params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.decoding_method in (
+ "modified_beam_search",
+ "modified_beam_search_LODR",
+ ):
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_shallow_fusion:
+ params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
+
+ if "LODR" in params.decoding_method:
+ params.suffix += (
+ f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
+ )
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ # only load the neural network LM if required
+ if params.use_shallow_fusion or params.decoding_method in (
+ "modified_beam_search_lm_rescore",
+ "modified_beam_search_lm_rescore_LODR",
+ "modified_beam_search_lm_shallow_fusion",
+ "modified_beam_search_LODR",
+ ):
+ LM = LmScorer(
+ lm_type=params.lm_type,
+ params=params,
+ device=device,
+ lm_scale=params.lm_scale,
+ )
+ LM.to(device)
+ LM.eval()
+ else:
+ LM = None
+
+ # only load N-gram LM when needed
+ if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+ try:
+ import kenlm
+ except ImportError:
+ print("Please install kenlm first. You can use")
+ print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
+ print("to install it")
+ import sys
+
+ sys.exit(-1)
+ ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+ logging.info(f"lm filename: {ngram_file_name}")
+ ngram_lm = kenlm.Model(ngram_file_name)
+ ngram_lm_scale = None # use a list to search
+
+ elif params.decoding_method == "modified_beam_search_LODR":
+ lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+ logging.info(f"Loading token level lm: {lm_filename}")
+ ngram_lm = NgramLm(
+ str(params.lang_dir / lm_filename),
+ backoff_id=params.backoff_id,
+ is_binary=False,
+ )
+ logging.info(f"num states: {ngram_lm.lm.num_states}")
+ ngram_lm_scale = params.ngram_lm_scale
+ else:
+ ngram_lm = None
+ ngram_lm_scale = None
+
+ if "fast_beam_search" in params.decoding_method:
+ if params.decoding_method == "fast_beam_search_nbest_LG":
+ lexicon = Lexicon(params.lang_dir)
+ word_table = lexicon.word_table
+ lg_filename = params.lang_dir / "LG.pt"
+ logging.info(f"Loading {lg_filename}")
+ decoding_graph = k2.Fsa.from_dict(
+ torch.load(lg_filename, map_location=device)
+ )
+ decoding_graph.scores *= params.ngram_lm_scale
+ else:
+ word_table = None
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+ word_table = None
+
+ if "modified_beam_search" in params.decoding_method:
+ if os.path.exists(params.context_file):
+ contexts = []
+ for line in open(params.context_file).readlines():
+ contexts.append((sp.encode(line.strip()), 0.0))
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+ gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts()
+
+ dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts)
+ test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts)
+
+ test_sets = ["dev", "test"]
+ test_dl = [dev_dl, test_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dl):
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ word_table=word_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ LM=LM,
+ ngram_lm=ngram_lm,
+ ngram_lm_scale=ngram_lm_scale,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_lora/decoder.py b/egs/librispeech/ASR/zipformer_lora/decoder.py
new file mode 120000
index 000000000..cab465d2b
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/decoder.py
@@ -0,0 +1 @@
+../zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/encoder_interface.py b/egs/librispeech/ASR/zipformer_lora/encoder_interface.py
new file mode 120000
index 000000000..aa5d0217a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/encoder_interface.py
@@ -0,0 +1 @@
+../transducer_stateless/encoder_interface.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/export.py b/egs/librispeech/ASR/zipformer_lora/export.py
new file mode 100755
index 000000000..d47666bef
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/export.py
@@ -0,0 +1,543 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+Note: This is an example for the LibriSpeech dataset. If you are using a
+different dataset, you should change the argument values accordingly.
+
+(1) Export to torchscript model using torch.jit.script()
+
+- For non-streaming model:
+
+./zipformer_lora/export.py \
+ --exp-dir ./zipformer_lora/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("jit_script.pt")`.
+
+Check ./jit_pretrained.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+- For streaming model:
+
+./zipformer_lora/export.py \
+ --exp-dir ./zipformer_lora/exp \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
+You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
+
+Check ./jit_pretrained_streaming.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+- For non-streaming model:
+
+./zipformer_lora/export.py \
+ --exp-dir ./zipformer_lora/exp \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9
+
+- For streaming model:
+
+./zipformer_lora/export.py \
+ --exp-dir ./zipformer_lora/exp \
+ --causal 1 \
+ --tokens data/lang_bpe_500/tokens.txt \
+ --epoch 30 \
+ --avg 9
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
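+A minimal sketch of loading it in Python (assuming `params` and `get_model()` are
+set up the same way as during training, e.g. as in ./pretrained.py; the path below
+is illustrative):
+
+    from icefall.checkpoint import load_checkpoint
+
+    model = get_model(params)  # same architecture arguments as used for training
+    load_checkpoint("zipformer_lora/exp/pretrained.pt", model)
+    model.eval()
+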
+- For non-streaming model:
+
+To use the generated file with `zipformer_lora/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/librispeech/ASR
+ ./zipformer_lora/decode.py \
+ --exp-dir ./zipformer_lora/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+
+- For streaming model:
+
+To use the generated file with `zipformer_lora/decode.py` and `zipformer_lora/streaming_decode.py`, you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/librispeech/ASR
+
+ # simulated streaming decoding
+ ./zipformer_lora/decode.py \
+ --exp-dir ./zipformer_lora/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+
+ # chunk-wise streaming decoding
+ ./zipformer_lora/streaming_decode.py \
+ --exp-dir ./zipformer_lora/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+- non-streaming model:
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+
+- streaming model:
+https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
+
+with the following commands:
+
+ sudo apt-get install git-lfs
+ git lfs install
+ git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+ git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
+ # You will find the pre-trained models in exp dir
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params
+from scaling_converter import convert_scaled_to_non_scaled
+from torch import Tensor, nn
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import make_pad_mask, num_tokens, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer_lora/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_bpe_500/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named jit_script.pt.
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+class EncoderModel(nn.Module):
+ """A wrapper for encoder and encoder_embed"""
+
+ def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+ super().__init__()
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+
+ def forward(
+ self, features: Tensor, feature_lengths: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ features: (N, T, C)
+ feature_lengths: (N,)
+ """
+ x, x_lens = self.encoder_embed(features, feature_lengths)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ return encoder_out, encoder_out_lens
+
+
+class StreamingEncoderModel(nn.Module):
+ """A wrapper for encoder and encoder_embed"""
+
+ def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+ super().__init__()
+ assert len(encoder.chunk_size) == 1, encoder.chunk_size
+ assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
+ self.chunk_size = encoder.chunk_size[0]
+ self.left_context_len = encoder.left_context_frames[0]
+
+        # The encoder_embed subsamples features: T -> (T - 7) // 2
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ self.pad_length = 7 + 2 * 3
+
+ self.encoder = encoder
+ self.encoder_embed = encoder_embed
+
+ def forward(
+ self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
+ ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """Streaming forward for encoder_embed and encoder.
+
+ Args:
+ features: (N, T, C)
+ feature_lengths: (N,)
+ states: a list of Tensors
+
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ chunk_size = self.chunk_size
+ left_context_len = self.left_context_len
+
+ cached_embed_left_pad = states[-2]
+ x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lengths,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = self.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+ @torch.jit.export
+ def get_init_states(
+ self,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+ ) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+ """
+ states = self.encoder.get_init_states(batch_size, device)
+
+ embed_states = self.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ # if torch.cuda.is_available():
+ # device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ # merge the LoRA weights
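+    # Calling eval() is expected to merge the LoRA deltas into the base weights of
+    # the LoRA-aware layers, so that state_dict() below already contains the merged
+    # parameters; they are then copied into a plain (non-LoRA) model.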
+ model.eval()
+
+ params.use_lora = False
+ base_model = get_model(params)
+
+ new_state_dict = {}
+ state_dict = model.state_dict()
+ param_names = base_model.state_dict().keys()
+ for k in param_names:
+ assert k in state_dict.keys()
+ new_state_dict[k] = state_dict[k]
+
+ base_model.load_state_dict(new_state_dict, strict=True)
+
+ model = base_model
+ model.eval()
+
+ if params.jit is True:
+ convert_scaled_to_non_scaled(model, inplace=True)
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+
+ # Wrap encoder and encoder_embed as a module
+ if params.causal:
+ model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
+ chunk_size = model.encoder.chunk_size
+ left_context_len = model.encoder.left_context_len
+ filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
+ else:
+ model.encoder = EncoderModel(model.encoder, model.encoder_embed)
+ filename = "jit_script.pt"
+
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ model.save(str(params.exp_dir / filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torchscript. Export model.state_dict()")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py
new file mode 100755
index 000000000..0464cf65c
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/finetune.py
@@ -0,0 +1,1553 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# Fine-tune without mux (i.e. not mixing with the original training data):
+./zipformer_lora/finetune.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --do-finetune 1 \
+ --finetune-ckpt path/to/ckpt \
+ --base-lr 0.0045 \
+ --use-mux 0 \
+ --exp-dir zipformer/exp_finetune \
+ --max-duration 1000
+
+# Fine-tune with mux (i.e. mixing with the original training data):
+./zipformer_lora/finetune.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --do-finetune 1 \
+ --finetune-ckpt path/to/ckpt \
+ --base-lr 0.0045 \
+ --use-mux 1 \
+ --exp-dir zipformer/exp_finetune \
+ --max-duration 1000
+
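+After fine-tuning, the LoRA weights can be merged back into the base model and
+exported with ./zipformer_lora/export.py.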
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut, CutSet
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
+    # Note that we add a very large constant here so that the ScheduledFloat
+    # variables reach their end values.
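+    # For example, with --max-duration 1000, --world-size 4 and --ref-duration 600,
+    # each training batch counts as 1000 * 4 / 600, i.e. about 6.7 reference batches.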
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ ) + 100000
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--do-finetune",
+ type=str2bool,
+ default=True,
+ help="If true, finetune from a pre-trained checkpoint",
+ )
+
+ parser.add_argument(
+ "--use-mux",
+ type=str2bool,
+ default=False,
+ help="""
+ Whether to adapt. If true, we will mix 5% of the new data
+ with 95% of the original data to fine-tune. This is useful
+ if you want to maintain the performance on the original domain
+ """,
+ )
+
+ parser.add_argument(
+        "--use-lora", type=str2bool, default=True, help="Whether to use LoRA for fine-tuning"
+ )
+
+ parser.add_argument(
+ "--lora-r", type=int, default=0, help="The bottleneck dimension of LoRA"
+ )
+
+ parser.add_argument(
+ "--init-modules",
+ type=str,
+ default=None,
+ help="""
+        Modules to be initialized. It matches all parameters starting with
+        a specific key. The keys are given as a comma-separated list. If None,
+        all modules will be initialised. For example, if you only want to
+        initialise all parameters starting with "encoder", use "encoder";
+        if you want to initialise parameters starting with encoder or joiner,
+        use "encoder,joiner".
+ """,
+ )
+
+ parser.add_argument(
+ "--finetune-ckpt",
+ type=str,
+ default=None,
+        help="The checkpoint to fine-tune from (path to a .pt file)",
+ )
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr",
+ type=float,
+ default=0.045,
+ help="""The base learning rate.
+ It is set to a very small value as we are doing fine-tuning""",
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=100000.0,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. It is set to a very large value here to prevent the lr from decaying too fast
+ during fine-tuning.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=100.0,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ It is set to a very large value here to prevent the lr from decaying too fast
+ during fine-tuning.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+        help="The prune range for rnnt loss. It means how many symbols (context) "
+        "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (the joiner is just an addition); this simple loss is also used "
+        "for training (as a regularization term). We will scale the simple "
+        "loss with this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+        help="""Save checkpoint after processing this number of batches
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+ add_finetune_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+ epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
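+    # e.g. "2,2,3,4,3,2" -> (2, 2, 3, 4, 3, 2)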
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ use_lora=params.use_lora,
+ lora_r=params.lora_r if params.use_lora else 0,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def load_model_params(
+ ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True
+):
+ """Load model params from checkpoint
+
+ Args:
+ ckpt (str): Path to the checkpoint
+ model (nn.Module): model to be loaded
+ init_modules (list[str]): List of modules to be initialized
+
+ """
+ logging.info(f"Loading checkpoint from {ckpt}")
+ checkpoint = torch.load(ckpt, map_location="cpu")
+
+ # if module list is empty, load the whole model from ckpt
+ if not init_modules:
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ model.load_state_dict(checkpoint["model"], strict=strict)
+ else:
+ src_state_dict = checkpoint["model"]
+ dst_state_dict = model.state_dict()
+ for module in init_modules:
+ logging.info(f"Loading parameters starting with prefix {module}")
+ src_keys = [
+ k for k in src_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ dst_keys = [
+ k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".")
+ ]
+ assert set(src_keys) == set(dst_keys) # two sets should match exactly
+ for key in src_keys:
+ dst_state_dict[key] = src_state_dict.pop(key)
+
+ model.load_state_dict(dst_state_dict, strict=strict)
+
+ return None
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+            # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+    valid_dls: List[torch.utils.data.DataLoader],
+ valid_sets: List[str],
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+      valid_dls:
+        Dataloaders for the validation datasets.
+      valid_sets:
+        Names of the validation datasets, parallel to `valid_dls`.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
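+    # Put only the LoRA modules in training mode; the frozen base modules keep
+    # eval-mode behaviour (e.g. for dropout). Note that .training is set directly
+    # on each named module instead of calling train()/eval(), which would toggle
+    # all children recursively.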
+ for name, m in model.named_modules():
+ if "lora" in name:
+ m.training = True
+ else:
+ m.training = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ for valid_set, valid_dl in zip(valid_sets, valid_dls):
+ logging.info(f"Computing validation loss on {valid_set}")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ logging.info(
+ f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
+ )
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
+ )
+ model.train()
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ # load model parameters for model fine-tuning
+ if params.do_finetune:
+ assert params.start_epoch == 1, "Fine-tune must start from epoch 1"
+ modules = params.init_modules.split(",") if params.init_modules else None
+ checkpoints = load_model_params(
+ ckpt=params.finetune_ckpt, model=model, init_modules=modules, strict=False
+ )
+        # Need to update model_avg if we use initialisation
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+ else:
+ # resuming training
+ assert params.start_epoch > 1, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ # keep the original model untouched, only update the adapters
+ num_trainable = 0
+ for name, p in model.named_parameters():
+ if "lora_A" in name or "lora_B" in name:
+ p.requires_grad = True
+ num_trainable += p.numel()
+ else:
+ p.requires_grad = False
+
+ logging.info(
+ "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
+ num_trainable, num_trainable / num_param * 100
+ )
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ gigaspeech_cuts = librispeech.gigaspeech_subset_small_cuts()
+ if params.use_mux:
+ librispeech_cuts = librispeech.train_all_shuf_cuts()
+ train_cuts = CutSet.mux(
+ gigaspeech_cuts, # num cuts = 688182
+ librispeech_cuts, # num cuts = 843723
+ weights=[688182, 843723],
+ stop_early=True,
+ )
+ else:
+ train_cuts = gigaspeech_cuts
+ logging.info(train_cuts)
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
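+        # For example, a 10-second cut at 100 frames per second has num_frames = 1000,
+        # giving T = ((1000 - 7) // 2 + 1) // 2 = 248 frames after subsampling.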
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
+
+ valid_sets = ["librispeech", "gigaspeech"]
+ valid_dls = [
+ librispeech.valid_dataloaders(valid_cuts),
+ librispeech.valid_dataloaders(gigaspeech_dev_cuts),
+ ]
+
+ # if not params.print_diagnostics:
+ # scan_pessimistic_batches_for_oom(
+ # model=model,
+ # train_dl=train_dl,
+ # optimizer=optimizer,
+ # sp=sp,
+ # params=params,
+ # )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dls=valid_dls,
+ valid_sets=valid_sets,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_lora/joiner.py b/egs/librispeech/ASR/zipformer_lora/joiner.py
new file mode 120000
index 000000000..444cb5f15
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/joiner.py
@@ -0,0 +1 @@
+../zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/model.py b/egs/librispeech/ASR/zipformer_lora/model.py
new file mode 120000
index 000000000..0c6fe6112
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/model.py
@@ -0,0 +1 @@
+../zipformer/model.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/optim.py b/egs/librispeech/ASR/zipformer_lora/optim.py
new file mode 120000
index 000000000..207eecfcd
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/optim.py
@@ -0,0 +1 @@
+../zipformer/optim.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py
new file mode 100644
index 000000000..3149db9f3
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/scaling.py
@@ -0,0 +1,2052 @@
+# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import math
+import random
+from typing import Optional, Tuple, Union
+
+import k2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+
+def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
+ max_value = torch.max(x, y)
+ diff = torch.abs(x - y)
+ return max_value + torch.log1p(torch.exp(-diff))
+
+
+# RuntimeError: Exporting the operator logaddexp to ONNX opset version
+# 14 is not supported. Please feel free to request support or submit
+# a pull request on PyTorch GitHub.
+#
+# The following function is to solve the above error when exporting
+# models to ONNX via torch.jit.trace()
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+ # Caution(fangjun): Put torch.jit.is_scripting() before
+ # torch.onnx.is_in_onnx_export();
+ # otherwise, it will cause errors for torch.jit.script().
+ #
+ # torch.logaddexp() works for both torch.jit.script() and
+ # torch.jit.trace() but it causes errors for ONNX export.
+ #
+ if torch.jit.is_scripting():
+ # Note: We cannot use torch.jit.is_tracing() here as it also
+ # matches torch.onnx.export().
+ return torch.logaddexp(x, y)
+ elif torch.onnx.is_in_onnx_export():
+ return logaddexp_onnx(x, y)
+ else:
+ # for torch.jit.trace()
+ return torch.logaddexp(x, y)
+
+
+class PiecewiseLinear(object):
+ """
+ Piecewise linear function, from float to float, specified as a nonempty list of (x,y) pairs with
+ the x values in increasing order. x values below the initial x or above the final x are mapped to
+ the initial y and the final y, respectively.
+ """
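+ # Illustrative example (added for clarity, not part of the original code): with
+ #   f = PiecewiseLinear((0.0, 1.0), (10.0, 3.0))
+ # f(5.0) == 2.0 by linear interpolation, while f(-1.0) == 1.0 and f(20.0) == 3.0,
+ # since inputs outside the specified range clamp to the endpoint y values.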
+
+ def __init__(self, *args):
+ assert len(args) >= 1, len(args)
+ if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
+ self.pairs = list(args[0].pairs)
+ else:
+ self.pairs = [(float(x), float(y)) for x, y in args]
+ for x, y in self.pairs:
+ assert isinstance(x, (float, int)), type(x)
+ assert isinstance(y, (float, int)), type(y)
+
+ for i in range(len(self.pairs) - 1):
+ assert self.pairs[i + 1][0] > self.pairs[i][0], (
+ i,
+ self.pairs[i],
+ self.pairs[i + 1],
+ )
+
+ def __str__(self):
+ # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
+ return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
+
+ def __call__(self, x):
+ if x <= self.pairs[0][0]:
+ return self.pairs[0][1]
+ elif x >= self.pairs[-1][0]:
+ return self.pairs[-1][1]
+ else:
+ cur_x, cur_y = self.pairs[0]
+ for i in range(1, len(self.pairs)):
+ next_x, next_y = self.pairs[i]
+ if x >= cur_x and x <= next_x:
+ return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
+ cur_x, cur_y = next_x, next_y
+ assert False
+
+ def __mul__(self, alpha):
+ return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
+
+ def __add__(self, x):
+ if isinstance(x, (float, int)):
+ return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
+ s, x = self.get_common_basis(x)
+ return PiecewiseLinear(
+ *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def max(self, x):
+ if isinstance(x, (float, int)):
+ x = PiecewiseLinear((0, x))
+ s, x = self.get_common_basis(x, include_crossings=True)
+ return PiecewiseLinear(
+ *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def min(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ x = PiecewiseLinear((0, x))
+ s, x = self.get_common_basis(x, include_crossings=True)
+ return PiecewiseLinear(
+ *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def __eq__(self, other):
+ return self.pairs == other.pairs
+
+ def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
+ """
+ Returns (self_mod, p_mod) which are equivalent piecewise linear
+ functions to self and p, but with the same x values.
+
+ p: the other piecewise linear function
+ include_crossings: if true, include in the x values the positions
+ where the functions indicated by self and p cross.
+ """
+ assert isinstance(p, PiecewiseLinear), type(p)
+
+ # get sorted x-values without repetition.
+ x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
+ y_vals1 = [self(x) for x in x_vals]
+ y_vals2 = [p(x) for x in x_vals]
+
+ if include_crossings:
+ extra_x_vals = []
+ for i in range(len(x_vals) - 1):
+ if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
+ # if the two lines in this subsegment potentially cross each other..
+ diff_cur = abs(y_vals1[i] - y_vals2[i])
+ diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
+ # `pos`, between 0 and 1, gives the relative x position,
+ # with 0 being x_vals[i] and 1 being x_vals[i+1].
+ pos = diff_cur / (diff_cur + diff_next)
+ extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
+ extra_x_vals.append(extra_x_val)
+ if len(extra_x_vals) > 0:
+ x_vals = sorted(set(x_vals + extra_x_vals))
+ y_vals1 = [self(x) for x in x_vals]
+ y_vals2 = [p(x) for x in x_vals]
+ return (
+ PiecewiseLinear(*zip(x_vals, y_vals1)),
+ PiecewiseLinear(*zip(x_vals, y_vals2)),
+ )
+
+
+class ScheduledFloat(torch.nn.Module):
+ """
+ This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+ it does not have a working forward() function. You are supposed to cast it to float, as
+ in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+ It is a floating point value whose value changes depending on the batch count of the
+ training loop. It is a piecewise linear function where you specify the (x,y) pairs
+ in sorted order on x; x corresponds to the batch index. For batch-index values before the
+ first x or after the last x, we just use the first or last y value.
+
+ Example:
+ self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+ `default` is used when self.batch_count is not set or not in training mode or in
+ torch.jit scripting mode.
+ """
+
+ def __init__(self, *args, default: float = 0.0):
+ super().__init__()
+ # self.batch_count and self.name will be written to in the training loop.
+ self.batch_count = None
+ self.name = None
+ self.default = default
+ self.schedule = PiecewiseLinear(*args)
+
+ def extra_repr(self) -> str:
+ return (
+ f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs)[1:-1]}"
+ )
+
+ def __float__(self):
+ batch_count = self.batch_count
+ if (
+ batch_count is None
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return float(self.default)
+ else:
+ ans = self.schedule(self.batch_count)
+ if random.random() < 0.0002:
+ logging.info(
+ f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}"
+ )
+ return ans
+
+ def __add__(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ return ScheduledFloat(self.schedule + x, default=self.default)
+ else:
+ return ScheduledFloat(
+ self.schedule + x.schedule, default=self.default + x.default
+ )
+
+ def max(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ return ScheduledFloat(self.schedule.max(x), default=self.default)
+ else:
+ return ScheduledFloat(
+ self.schedule.max(x.schedule), default=max(self.default, x.default)
+ )
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
+def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
+ """
+ A randomized way of casting a floating point value to half precision.
+ """
+ if x.dtype == torch.float16:
+ return x
+ x_abs = x.abs()
+ is_too_small = x_abs < min_abs
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
+ # for those elements].
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
+
+
+class CutoffEstimator:
+ """
+ Estimates cutoffs of an arbitrary numerical quantity such that a specified
+ proportion of items will be above the cutoff on average.
+
+ p is the proportion of items that should be above the cutoff.
+ """
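+ # Rough illustration (an added note, not from the original): CutoffEstimator(0.25)
+ # adapts its cutoff so that, over many calls, roughly 25% of the observed values
+ # fall above it, and __call__ returns True for exactly those values.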
+
+ def __init__(self, p: float):
+ self.p = p
+ # total count of items
+ self.count = 0
+ # total count of items that were above the cutoff
+ self.count_above = 0
+ # initial cutoff value
+ self.cutoff = 0
+
+ def __call__(self, x: float) -> bool:
+ """
+ Returns true if x is above the cutoff.
+ """
+ ans = x > self.cutoff
+ self.count += 1
+ if ans:
+ self.count_above += 1
+ cur_p = self.count_above / self.count
+ delta_p = cur_p - self.p
+ if (delta_p > 0) == ans:
+ q = abs(delta_p)
+ self.cutoff = x * q + self.cutoff * (1 - q)
+ return ans
+
+
+class SoftmaxFunction(torch.autograd.Function):
+ """
+ Tries to handle half-precision derivatives in a randomized way that should
+ be more accurate for training than the default behavior.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor, dim: int):
+ ans = x.softmax(dim=dim)
+ # if x dtype is float16, x.softmax() returns a float32 because
+ # (presumably) that op does not support float16, and autocast
+ # is enabled.
+ if torch.is_autocast_enabled():
+ ans = ans.to(torch.float16)
+ ctx.save_for_backward(ans)
+ ctx.x_dtype = x.dtype
+ ctx.dim = dim
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ (ans,) = ctx.saved_tensors
+ with torch.cuda.amp.autocast(enabled=False):
+ ans_grad = ans_grad.to(torch.float32)
+ ans = ans.to(torch.float32)
+ x_grad = ans_grad * ans
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+ return x_grad, None
+
+
+def softmax(x: Tensor, dim: int):
+ if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x.softmax(dim=dim)
+
+ return SoftmaxFunction.apply(x, dim)
+
+
+class MaxEigLimiterFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ coeffs: Tensor,
+ direction: Tensor,
+ channel_dim: int,
+ grad_scale: float,
+ ) -> Tensor:
+ ctx.channel_dim = channel_dim
+ ctx.grad_scale = grad_scale
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad, *args):
+ with torch.enable_grad():
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
+ x_orig.requires_grad = True
+ num_channels = x_orig.shape[ctx.channel_dim]
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
+ new_direction.requires_grad = False
+ x = x - x.mean(dim=0)
+ x_var = (x**2).mean()
+ x_residual = x - coeffs * new_direction
+ x_residual_var = (x_residual**2).mean()
+ # `variance_proportion` is the proportion of the variance accounted for
+ # by the top eigen-direction. This is to be minimized.
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+ variance_proportion.backward()
+ x_orig_grad = x_orig.grad
+ x_extra_grad = (
+ x_orig.grad
+ * ctx.grad_scale
+ * x_grad.norm()
+ / (x_orig_grad.norm() + 1.0e-20)
+ )
+ return x_grad + x_extra_grad.detach(), None, None, None, None
+
+
+class BiasNormFunction(torch.autograd.Function):
+ # This computes:
+ # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
+ # return x * scales
+ # (after unsqueezing the bias), but it does it in a memory-efficient way so that
+ # it can just store the returned value (chances are, this will also be needed for
+ # some other reason, related to the next operation, so we can save memory).
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ bias: Tensor,
+ log_scale: Tensor,
+ channel_dim: int,
+ store_output_for_backprop: bool,
+ ) -> Tensor:
+ assert bias.ndim == 1
+ if channel_dim < 0:
+ channel_dim = channel_dim + x.ndim
+ ctx.store_output_for_backprop = store_output_for_backprop
+ ctx.channel_dim = channel_dim
+ for _ in range(channel_dim + 1, x.ndim):
+ bias = bias.unsqueeze(-1)
+ scales = (
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
+ ) * log_scale.exp()
+ ans = x * scales
+ ctx.save_for_backward(
+ ans.detach() if store_output_for_backprop else x,
+ scales.detach(),
+ bias.detach(),
+ log_scale.detach(),
+ )
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor) -> Tensor:
+ ans_or_x, scales, bias, log_scale = ctx.saved_tensors
+ if ctx.store_output_for_backprop:
+ x = ans_or_x / scales
+ else:
+ x = ans_or_x
+ x = x.detach()
+ x.requires_grad = True
+ bias.requires_grad = True
+ log_scale.requires_grad = True
+ with torch.enable_grad():
+ # recompute scales from x, bias and log_scale.
+ scales = (
+ torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
+ ) * log_scale.exp()
+ ans = x * scales
+ ans.backward(gradient=ans_grad)
+ return x.grad, bias.grad.flatten(), log_scale.grad, None, None
+
+
+class BiasNorm(torch.nn.Module):
+ """
+ This is intended to be a simpler, and hopefully cheaper, replacement for
+ LayerNorm. The observation this is based on is that Transformer-type
+ networks, especially with pre-norm, sometimes seem to set one of the
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
+ the LayerNorm because the output magnitude is then not strongly dependent
+ on the other (useful) features. Presumably the weight and bias of the
+ LayerNorm are required to allow it to do this.
+
+ Instead, we give the BiasNorm a trainable bias that it can use when
+ computing the scale for normalization. We also give it a (scalar)
+ trainable scale on the output.
+
+
+ Args:
+ num_channels: the number of channels, e.g. 512.
+ channel_dim: the axis/dimension corresponding to the channel,
+ interpreted as an offset from the input's ndim if negative.
+ This is NOT the num_channels; it should typically be one of
+ {-2, -1, 0, 1, 2, 3}.
+ log_scale: the initial log-scale that we multiply the output by; this
+ is learnable.
+ log_scale_min: FloatLike, minimum allowed value of log_scale
+ log_scale_max: FloatLike, maximum allowed value of log_scale
+ store_output_for_backprop: only affects memory use; we recommend
+ setting this to True if you think the output of this module is more likely
+ than its input to need to be stored for the backprop.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int = -1, # CAUTION: see documentation.
+ log_scale: float = 1.0,
+ log_scale_min: float = -1.5,
+ log_scale_max: float = 1.5,
+ store_output_for_backprop: bool = False,
+ ) -> None:
+ super(BiasNorm, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.log_scale = nn.Parameter(torch.tensor(log_scale))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+
+ self.log_scale_min = log_scale_min
+ self.log_scale_max = log_scale_max
+
+ self.store_output_for_backprop = store_output_for_backprop
+
+ def forward(self, x: Tensor) -> Tensor:
+ assert x.shape[self.channel_dim] == self.num_channels
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ channel_dim = self.channel_dim
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ bias = self.bias
+ for _ in range(channel_dim + 1, x.ndim):
+ bias = bias.unsqueeze(-1)
+ scales = (
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
+ ) * self.log_scale.exp()
+ return x * scales
+
+ log_scale = limit_param_value(
+ self.log_scale,
+ min=float(self.log_scale_min),
+ max=float(self.log_scale_max),
+ training=self.training,
+ )
+
+ return BiasNormFunction.apply(
+ x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop
+ )
+
+
+def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
+ """
+ Behaves like a constructor of a modified version of nn.Linear
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Linear accepts
+ e.g. in_features, out_features, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Linear(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
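+ # Illustrative usage (an added example with hypothetical sizes, not from the original file):
+ # a projection whose initial outputs are about 4x smaller than those of a default nn.Linear:
+ #   proj = ScaledLinear(256, 512, bias=True, initial_scale=0.25)
+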
+
+class LoRALayer:
+ def __init__(
+ self,
+ r: int,
+ lora_alpha: int,
+ lora_dropout: float,
+ merge_weights: bool,
+ ):
+ self.r = r
+ self.lora_alpha = lora_alpha
+ # Optional dropout
+ if lora_dropout > 0.0:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+ # Mark the weight as unmerged
+ self.merged = False
+ self.merge_weights = merge_weights
+
+
+class ScaledLinear_lora(nn.Linear, LoRALayer):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ fan_in_fan_out: bool = False,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.0,
+ initial_scale: float = 1.0,
+ merge_weights: bool = True,
+ **kwargs,
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(
+ self,
+ r=r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ merge_weights=merge_weights,
+ )
+
+ self.initial_scale = initial_scale
+ self.fan_in_fan_out = fan_in_fan_out
+ if r > 0:
+ self.lora_A = nn.Parameter(torch.full((r, in_features), 0.0))
+ self.lora_B = nn.Parameter(torch.full((out_features, r), 0.0))
+ self.scaling = self.lora_alpha / self.r
+ self.weight.requires_grad = False
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ # initialize the parameters
+ nn.Linear.reset_parameters(self)
+ if hasattr(self, "lora_A"):
+ initial_scale = self.initial_scale
+ with torch.no_grad():
+ self.weight[:] *= initial_scale
+ if self.bias is not None:
+ nn.init.uniform_(
+ self.bias, -0.1 * initial_scale, 0.1 * initial_scale
+ )
+ if hasattr(self, "lora_A"):
+ # initialize B the same way as the default for nn.Linear and A to zero
+ # this is different than what is described in the paper but should not affect performance
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def train(self, mode: bool = True):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+
+ nn.Linear.train(self, mode)
+ if mode:
+ # We don't want the weights to be merged in training mode
+ if self.merge_weights and self.merged:
+ if self.r > 0:
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = False
+ else:
+ # When evaluating the model, we merge the weights for simplicity
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+
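+ # When r > 0 and the weights are unmerged, the effective weight is
+ # W + (lora_alpha / r) * B @ A, i.e. a low-rank update applied on top of the
+ # frozen base linear layer.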
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ delta_result = (
+ self.lora_dropout(x)
+ @ self.lora_A.transpose(0, 1)
+ @ self.lora_B.transpose(0, 1)
+ )
+ return result + delta_result * self.scaling
+ else:
+ return F.linear(x, T(self.weight), bias=self.bias)
+
+
+def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
+ """
+ Behaves like a constructor of a modified version of nn.Conv1d
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Conv1d accepts
+ e.g. in_channels, out_channels, kernel_size, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Conv1d(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
+def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
+ """
+ Behaves like a constructor of a modified version of nn.Conv2d
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Conv2d accepts
+ e.g. in_channels, out_channels, kernel_size, bias=False, but:
+ NO PADDING-RELATED ARGS.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Conv2d(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
+class ChunkCausalDepthwiseConv1d(torch.nn.Module):
+ """
+ Behaves like a depthwise 1d convolution, except that it is causal in
+ a chunkwise way, as if we had a block-triangular attention mask.
+ The chunk size is provided at test time (it should probably be
+ kept in sync with the attention mask).
+
+ This has a little more than twice the parameters of a conventional
+ depthwise conv1d module: we implement it by having one
+ depthwise convolution, of half the width, that is causal (via
+ right-padding); and one depthwise convolution that is applied only
+ within chunks, that we multiply by a scaling factor which depends
+ on the position within the chunk.
+
+ Args:
+ channels: the number of channels, e.g. 256.
+ kernel_size: the kernel size of the chunkwise convolution; must be odd.
+ bias: if true, include bias terms in the convolutions.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ initial_scale: float = 1.0,
+ bias: bool = True,
+ ):
+ super().__init__()
+ assert kernel_size % 2 == 1
+
+ half_kernel_size = (kernel_size + 1) // 2
+ # will pad manually, on one side.
+ self.causal_conv = nn.Conv1d(
+ in_channels=channels,
+ out_channels=channels,
+ groups=channels,
+ kernel_size=half_kernel_size,
+ padding=0,
+ bias=True,
+ )
+
+ self.chunkwise_conv = nn.Conv1d(
+ in_channels=channels,
+ out_channels=channels,
+ groups=channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ bias=bias,
+ )
+
+ # first row is correction factors added to the scale near the left edge of the chunk,
+ # second row is correction factors added to the scale near the right edge of the chunk,
+ # both of these are added to a default scale of 1.0.
+ self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
+ self.kernel_size = kernel_size
+
+ with torch.no_grad():
+ self.causal_conv.weight[:] *= initial_scale
+ self.chunkwise_conv.weight[:] *= initial_scale
+ if bias:
+ torch.nn.init.uniform_(
+ self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
+ )
+
+ def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
+ """
+ Forward function. Args:
+ x: a Tensor of shape (batch_size, channels, seq_len)
+ chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
+ """
+ (batch_size, num_channels, seq_len) = x.shape
+
+ # half_kernel_size = (self.kernel_size + 1) // 2
+ # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
+ # in the causal conv. It's the amount by which we must pad on the left,
+ # to make the convolution causal.
+ left_pad = self.kernel_size // 2
+
+ if chunk_size < 0 or chunk_size > seq_len:
+ chunk_size = seq_len
+ right_pad = -seq_len % chunk_size
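+ # e.g. with seq_len=100 and chunk_size=16, right_pad is 12, so the padded
+ # length 112 divides evenly into 7 chunks of 16.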
+
+ x = torch.nn.functional.pad(x, (left_pad, right_pad))
+
+ x_causal = self.causal_conv(x[..., : left_pad + seq_len])
+ assert x_causal.shape == (batch_size, num_channels, seq_len)
+
+ x_chunk = x[..., left_pad:]
+ num_chunks = x_chunk.shape[2] // chunk_size
+ x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
+ x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
+ batch_size * num_chunks, num_channels, chunk_size
+ )
+ x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
+
+ chunk_scale = self._get_chunk_scale(chunk_size)
+
+ x_chunk = x_chunk * chunk_scale
+ x_chunk = x_chunk.reshape(
+ batch_size, num_chunks, num_channels, chunk_size
+ ).permute(0, 2, 1, 3)
+ x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
+ ..., :seq_len
+ ]
+
+ return x_chunk + x_causal
+
+ def _get_chunk_scale(self, chunk_size: int):
+ """Returns tensor of shape (num_channels, chunk_size) that will be used to
+ scale the output of self.chunkwise_conv."""
+ left_edge = self.chunkwise_conv_scale[0]
+ right_edge = self.chunkwise_conv_scale[1]
+ if chunk_size < self.kernel_size:
+ left_edge = left_edge[:, :chunk_size]
+ right_edge = right_edge[:, -chunk_size:]
+ else:
+ t = chunk_size - self.kernel_size
+ channels = left_edge.shape[0]
+ pad = torch.zeros(
+ channels, t, device=left_edge.device, dtype=left_edge.dtype
+ )
+ left_edge = torch.cat((left_edge, pad), dim=-1)
+ right_edge = torch.cat((pad, right_edge), dim=-1)
+ return 1.0 + (left_edge + right_edge)
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ cache: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Streaming Forward function.
+
+ Args:
+ x: a Tensor of shape (batch_size, channels, seq_len)
+ cache: cached left context of shape (batch_size, channels, left_pad)
+ """
+ (batch_size, num_channels, seq_len) = x.shape
+
+ # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
+ # in the causal conv. It's the amount by which we must pad on the left,
+ # to make the convolution causal.
+ left_pad = self.kernel_size // 2
+
+ # Pad cache
+ assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
+ x = torch.cat([cache, x], dim=2)
+ # Update cache
+ cache = x[..., -left_pad:]
+
+ x_causal = self.causal_conv(x)
+ assert x_causal.shape == (batch_size, num_channels, seq_len)
+
+ x_chunk = x[..., left_pad:]
+ x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
+
+ chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
+ x_chunk = x_chunk * chunk_scale
+
+ return x_chunk + x_causal, cache
+
+
+class BalancerFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ min_mean: float,
+ max_mean: float,
+ min_rms: float,
+ max_rms: float,
+ grad_scale: float,
+ channel_dim: int,
+ ) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ ctx.channel_dim = channel_dim
+ ctx.save_for_backward(x)
+ ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+ (x,) = ctx.saved_tensors
+ (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
+
+ try:
+ with torch.enable_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ x = x.to(torch.float32)
+ x = x.detach()
+ x.requires_grad = True
+ mean_dims = [i for i in range(x.ndim) if i != channel_dim]
+ uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
+ mean = x.mean(dim=mean_dims, keepdim=True)
+ stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
+ rms = uncentered_var.clamp(min=1.0e-20).sqrt()
+
+ m = mean / stddev
+ # part of loss that relates to mean / stddev
+ m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
+
+ # put a much larger scale on the RMS-max-limit loss, so that if both it and the
+ # m_loss are violated we fix the RMS loss first.
+ rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+ r_loss = (rms_clamped / rms).log().abs()
+
+ loss = m_loss + r_loss
+
+ loss.backward(gradient=torch.ones_like(loss))
+ loss_grad = x.grad
+ loss_grad_rms = (
+ (loss_grad**2)
+ .mean(dim=mean_dims, keepdim=True)
+ .sqrt()
+ .clamp(min=1.0e-20)
+ )
+
+ loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+
+ x_grad_float = x_grad.to(torch.float32)
+ # scale each element of loss_grad by the absolute value of the corresponding
+ # element of x_grad, which we view as a noisy estimate of its magnitude for that
+ # (frame and dimension). later we can consider factored versions.
+ x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+ x_grad = x_grad_mod.to(x_grad.dtype)
+ except Exception as e:
+ logging.info(
+ f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
+ )
+
+ return x_grad, None, None, None, None, None, None
+
+
+class Balancer(torch.nn.Module):
+ """
+ Modifies the backpropped derivatives of a function to try to encourage, for
+ each channel, that the proportion of positive values stays within
+ [min_positive, max_positive] and that the average absolute value (relative to
+ the channel mean) stays within [min_abs, max_abs]. It does this by adding a
+ small correction term to the gradient, with magnitude controlled by
+ grad_scale, whenever one of the constraints is violated.
+
+ Args:
+ num_channels: the number of channels
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+ min_positive: the minimum, per channel, of the proportion of the time
+ that (x > 0), below which we start to modify the derivatives.
+ max_positive: the maximum, per channel, of the proportion of the time
+ that (x > 0), above which we start to modify the derivatives.
+ grad_scale: determines the scale of the gradient correction we apply
+ once the constraints on {min,max}_positive and {min,max}_abs
+ are violated.
+ min_abs: the minimum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ max_abs: the maximum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ prob: determines the minimum probability with which we modify the
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
+ on each forward(). This is done randomly to prevent all layers
+ from doing it at the same time.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int,
+ min_positive: FloatLike = 0.05,
+ max_positive: FloatLike = 0.95,
+ min_abs: FloatLike = 0.2,
+ max_abs: FloatLike = 100.0,
+ grad_scale: FloatLike = 0.04,
+ prob: Optional[FloatLike] = None,
+ ):
+ super().__init__()
+
+ if prob is None:
+ prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
+ self.prob = prob
+ # 5% of the time we will return and do nothing because memory usage is
+ # too high.
+ self.mem_cutoff = CutoffEstimator(0.05)
+
+ # actually self.num_channels is no longer needed except for an assertion.
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.min_positive = min_positive
+ self.max_positive = max_positive
+ self.min_abs = min_abs
+ self.max_abs = max_abs
+ self.grad_scale = grad_scale
+
+ def forward(self, x: Tensor) -> Tensor:
+ if (
+ torch.jit.is_scripting()
+ or not x.requires_grad
+ or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
+ ):
+ return _no_op(x)
+
+ prob = float(self.prob)
+ if random.random() < prob:
+ # The following inner-functions convert from the way we historically specified
+ # these limitations, as limits on the absolute value and the proportion of positive
+ # values, to limits on the RMS value and the (mean / stddev).
+ def _abs_to_rms(x):
+ # for normally distributed data, if the expected absolute value is x, the
+ # expected rms value will be sqrt(pi/2) * x.
+ return 1.25331413732 * x
+
+ def _proportion_positive_to_mean(x):
+ def _atanh(x):
+ eps = 1.0e-10
+ # eps is to prevent crashes if x is exactly 0 or 1.
+ # we'll just end up returning a fairly large value.
+ return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
+
+ def _approx_inverse_erf(x):
+ # 1 / (sqrt(pi) * ln(2)),
+ # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
+ # this approximation is extremely crude and gets progressively worse for
+ # x very close to -1 or +1, but we mostly care about the "middle" region
+ # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
+ # and math.erf(0.0407316414078772) = 0.045935330944660666,
+ # which is pretty close to 0.05.
+ return 0.8139535143 * _atanh(x)
+
+ # first convert x from the range 0..1 to the range -1..1 which the error
+ # function returns
+ x = -1 + (2 * x)
+ return _approx_inverse_erf(x)
+
+ min_mean = _proportion_positive_to_mean(float(self.min_positive))
+ max_mean = _proportion_positive_to_mean(float(self.max_positive))
+ min_rms = _abs_to_rms(float(self.min_abs))
+ max_rms = _abs_to_rms(float(self.max_abs))
+ grad_scale = float(self.grad_scale)
+
+ assert x.shape[self.channel_dim] == self.num_channels
+
+ return BalancerFunction.apply(
+ x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
+ )
+ else:
+ return _no_op(x)
+
+
+def penalize_abs_values_gt(
+ x: Tensor, limit: float, penalty: float, name: str = None
+) -> Tensor:
+ """
+ Returns x unmodified, but in backprop will put a penalty for the excess of
+ the absolute values of elements of x over the limit "limit". E.g. if
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
+
+ Caution: the value of this penalty will be affected by grad scaling used
+ in automatic mixed precision training. For the purposes we use it for,
+ that shouldn't really matter, and may even be helpful; we just use this
+ to disallow really implausible values of scores from being given to softmax.
+
+ The name is for randomly printed debug info.
+ """
+ x_sign = x.sign()
+ over_limit = (x.abs() - limit) > 0
+ # The following is a memory efficient way to penalize the absolute values of
+ # x that's over the limit. (The memory efficiency comes when you think
+ # about which items torch needs to cache for the autograd, and which ones it
+ # can throw away). The numerical value of aux_loss as computed here will
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
+ # limit).relu().
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
+ # note: we don't do sum() here on aux_loss, but it's as if we had done
+ # sum() due to how with_loss() works.
+ x = with_loss(x, aux_loss, name)
+ # you must use x for something, or this will be ineffective.
+ return x
+
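+ # Illustrative call (hypothetical values, not taken from this diff):
+ #   scores = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04, name="attn_scores")
+ # leaves scores unchanged in the forward pass but penalizes entries with |value| > 25 in backprop.
+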
+
+def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
+ if x.ndim == 2:
+ return x.diag()
+ else:
+ (batch, dim, dim) = x.shape
+ x = x.reshape(batch, dim * dim)
+ x = x[:, :: dim + 1]
+ assert x.shape == (batch, dim)
+ return x
+
+
+def _whitening_metric(x: Tensor, num_groups: int):
+ """
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues
+ of the centered feature covariance are the same within each group's covariance matrix
+ and also between groups.
+ Args:
+ x: a Tensor of shape (*, num_channels)
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
+ Returns:
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
+ greater than 1.0 otherwise.
+ """
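+ # Equivalently: the metric is mean(lambda^2) / mean(lambda)^2 over the eigenvalues
+ # lambda of the per-group covariances, which equals 1.0 exactly when all eigenvalues
+ # are equal and grows as they spread apart.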
+ assert x.dtype != torch.float16
+ x = x.reshape(-1, x.shape[-1])
+ (num_frames, num_channels) = x.shape
+ assert num_channels % num_groups == 0
+ channels_per_group = num_channels // num_groups
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
+ # x now has shape (num_groups, num_frames, channels_per_group)
+ # subtract the mean so we use the centered, not uncentered, covariance.
+ # My experience has been that when we "mess with the gradients" like this,
+ # it's better not do anything that tries to move the mean around, because
+ # that can easily cause instability.
+ x = x - x.mean(dim=1, keepdim=True)
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
+ x_covar = torch.matmul(x.transpose(1, 2), x)
+ x_covar_mean_diag = _diag(x_covar).mean()
+ # the following expression is what we'd get if we took the matrix product
+ # of each covariance and measured the mean of its trace, i.e.
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
+ x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
+ return metric
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
+ ctx.save_for_backward(x)
+ ctx.module = module
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x_orig,) = ctx.saved_tensors
+ w = ctx.module
+
+ try:
+ with torch.enable_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ x_detached = x_orig.to(torch.float32).detach()
+ x_detached.requires_grad = True
+
+ metric = _whitening_metric(x_detached, w.num_groups)
+
+ if random.random() < 0.005 or __name__ == "__main__":
+ logging.info(
+ f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
+ f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}"
+ )
+
+ if metric < float(w.whitening_limit):
+ w.prob = w.min_prob
+ return x_grad, None
+ else:
+ w.prob = w.max_prob
+ metric.backward()
+ penalty_grad = x_detached.grad
+ scale = w.grad_scale * (
+ x_grad.to(torch.float32).norm()
+ / (penalty_grad.norm() + 1.0e-20)
+ )
+ penalty_grad = penalty_grad * scale
+ return x_grad + penalty_grad.to(x_grad.dtype), None
+ except Exception as e:
+ logging.info(
+ f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue."
+ )
+ return x_grad, None
+
+
+class Whiten(nn.Module):
+ def __init__(
+ self,
+ num_groups: int,
+ whitening_limit: FloatLike,
+ prob: Union[float, Tuple[float, float]],
+ grad_scale: FloatLike,
+ ):
+ """
+ Args:
+ num_groups: the number of groups to divide the channel dim into before
+ whitening. We will attempt to make the feature covariance
+ within each group, after mean subtraction, as "white" as possible,
+ while having the same trace across all groups.
+ whitening_limit: a value greater than 1.0, that dictates how much
+ freedom we have to violate the constraints. 1.0 would mean perfectly
+ white, with exactly the same trace across groups; larger values
+ give more freedom. E.g. 2.0.
+ prob: the probability with which we apply the gradient modification
+ (also affects the grad scale). May be supplied as a float,
+ or as a pair (min_prob, max_prob)
+
+ grad_scale: determines the scale on the gradient term from this object,
+ relative to the rest of the gradient on the attention weights.
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
+ """
+ super(Whiten, self).__init__()
+ assert num_groups >= 1
+ assert float(whitening_limit) >= 1
+ assert grad_scale >= 0
+ self.num_groups = num_groups
+ self.whitening_limit = whitening_limit
+ self.grad_scale = grad_scale
+
+ if isinstance(prob, float):
+ prob = (prob, prob)
+ (self.min_prob, self.max_prob) = prob
+ assert 0 < self.min_prob <= self.max_prob <= 1
+ self.prob = self.max_prob
+ self.name = None # will be set in training loop
+
+ def forward(self, x: Tensor) -> Tensor:
+ """
+ In the forward pass, this function just returns the input unmodified.
+ In the backward pass, it will modify the gradients to ensure that the
+ distribution in each group has close to (lambda times I) as the covariance
+ after mean subtraction, with the same lambda across groups.
+ For whitening_limit > 1, there will be more freedom to violate this
+ constraint.
+
+ Args:
+ x: the input of shape (*, num_channels)
+
+ Returns:
+ x, unmodified. You should make sure
+ you use the returned value, or the graph will be freed
+ and nothing will happen in backprop.
+ """
+ grad_scale = float(self.grad_scale)
+ if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
+ return _no_op(x)
+ else:
+ return WhiteningPenaltyFunction.apply(x, self)
+
+
+class WithLoss(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, y: Tensor, name: str):
+ ctx.y_shape = y.shape
+ if random.random() < 0.002 and name is not None:
+ loss_sum = y.sum().item()
+ logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
+ return x
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ return (
+ ans_grad,
+ torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
+ None,
+ )
+
+
+def with_loss(x, y, name):
+ # returns x but adds y.sum() to the loss function.
+ return WithLoss.apply(x, y, name)
+
+
+class ScaleGradFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, alpha: float) -> Tensor:
+ ctx.alpha = alpha
+ return x
+
+ @staticmethod
+ def backward(ctx, grad: Tensor):
+ return grad * ctx.alpha, None
+
+
+def scale_grad(x: Tensor, alpha: float):
+ return ScaleGradFunction.apply(x, alpha)
+
+
+class ScaleGrad(nn.Module):
+ def __init__(self, alpha: float):
+ super().__init__()
+ self.alpha = alpha
+
+ def forward(self, x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return x
+ return scale_grad(x, self.alpha)
+
+
+class LimitParamValue(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, min: float, max: float):
+ ctx.save_for_backward(x)
+ assert max >= min
+ ctx.min = min
+ ctx.max = max
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x,) = ctx.saved_tensors
+ # where x < ctx.min, ensure all grads are negative (this will tend to make
+ # x more positive).
+ x_grad = x_grad * torch.where(
+ torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
+ )
+ # where x > ctx.max, ensure all grads are positive (this will tend to make
+ # x more negative).
+ x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
+ return x_grad, None, None
+
+
+def limit_param_value(
+ x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
+):
+ # You apply this to (typically) an nn.Parameter during training to ensure that its
+ # elements (mostly) stay within a supplied range. This is done by modifying the
+ # gradients in backprop.
+ # It's not necessary to do this on every batch: do it only some of the time,
+ # to save a little time.
+ if training and random.random() < prob:
+ return LimitParamValue.apply(x, min, max)
+ else:
+ return x
+
+
+def _no_op(x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x
+ else:
+ # a no-op function that will have a node in the autograd graph,
+ # to avoid certain bugs relating to backward hooks
+ return x.chunk(1, dim=-1)[0]
+
+
+class Identity(torch.nn.Module):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, x):
+ return _no_op(x)
+
+
+class DoubleSwishFunction(torch.autograd.Function):
+ """
+ double_swish(x) = x * torch.sigmoid(x-1)
+
+ This is a definition, originally motivated by its close numerical
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
+
+ Memory-efficient derivative computation:
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
+ Now, s'(x) = s(x) * (1-s(x)).
+ double_swish'(x) = x * s'(x) + s(x).
+ = x * s(x) * (1-s(x)) + s(x).
+ = double_swish(x) * (1-s(x)) + s(x)
+ ... so we just need to remember s(x) but not x itself.
+ """
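+ # The backward pass below stores the derivative quantized to uint8 over [floor, ceil];
+ # adding uniform noise before flooring keeps the quantization unbiased in expectation
+ # while using less activation memory than caching the derivative in floating point.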
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ s = torch.sigmoid(x - 1.0)
+ y = x * s
+
+ if requires_grad:
+ deriv = y * (1 - s) + s
+
+ # notes on derivative of x * sigmoid(x - 1):
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
+ # min \simeq -0.043638. Take floor as -0.044 so it's a lower bound
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
+ # floors), should be expectation-preserving.
+ floor = -0.044
+ ceil = 1.2
+ d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ deriv
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ floor = -0.044
+ ceil = 1.2
+
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class DoubleSwish(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+ that we approximate closely with x * sigmoid(x-1).
+ """
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x * torch.sigmoid(x - 1.0)
+ return DoubleSwishFunction.apply(x)
+
+
+# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
+class Dropout2(nn.Module):
+ def __init__(self, p: FloatLike):
+ super().__init__()
+ self.p = p
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
+
+
+class MulForDropout3(torch.autograd.Function):
+ # returns (x * y * alpha) where alpha is a float and y doesn't require
+ # grad and is zero-or-one.
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, x, y, alpha):
+ assert not y.requires_grad
+ ans = x * y * alpha
+ ctx.save_for_backward(ans)
+ ctx.alpha = alpha
+ return ans
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, ans_grad):
+ (ans,) = ctx.saved_tensors
+ x_grad = ctx.alpha * ans_grad * (ans != 0)
+ return x_grad, None, None
+
+
+# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
+# and it lets you choose one dimension to share the dropout mask over
+class Dropout3(nn.Module):
+ def __init__(self, p: FloatLike, shared_dim: int):
+ super().__init__()
+ self.p = p
+ self.shared_dim = shared_dim
+
+ def forward(self, x: Tensor) -> Tensor:
+ p = float(self.p)
+ if not self.training or p == 0:
+ return _no_op(x)
+ scale = 1.0 / (1 - p)
+ rand_shape = list(x.shape)
+ rand_shape[self.shared_dim] = 1
+ mask = torch.rand(*rand_shape, device=x.device) > p
+ ans = MulForDropout3.apply(x, mask, scale)
+ return ans
+
+
+class SwooshLFunction(torch.autograd.Function):
+ """
+ swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+ coeff = -0.08
+
+ with torch.cuda.amp.autocast(enabled=False):
+ with torch.enable_grad():
+ x = x.detach()
+ x.requires_grad = True
+ y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
+
+ if not requires_grad:
+ return y
+
+ y.backward(gradient=torch.ones_like(y))
+
+ grad = x.grad
+ floor = coeff
+ ceil = 1.0 + coeff + 0.005
+
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ grad
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+
+ coeff = -0.08
+ floor = coeff
+ ceil = 1.0 + coeff + 0.005
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class SwooshL(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-L activation."""
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+ if not x.requires_grad:
+ return k2.swoosh_l_forward(x)
+ else:
+ return k2.swoosh_l(x)
+ # return SwooshLFunction.apply(x)
+
+
+class SwooshLOnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-L activation."""
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
+
+
+class SwooshRFunction(torch.autograd.Function):
+ """
+ swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
+
+ derivatives are between -0.08 and 0.92.
+ """
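+ # Note: 0.313261687 = log(1 + exp(-1)), so swoosh_r(0) = 0.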
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ with torch.enable_grad():
+ x = x.detach()
+ x.requires_grad = True
+ y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
+
+ if not requires_grad:
+ return y
+ y.backward(gradient=torch.ones_like(y))
+
+ grad = x.grad
+ floor = -0.08
+ ceil = 0.925
+
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ grad
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ floor = -0.08
+ ceil = 0.925
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class SwooshR(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-R activation."""
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
+ if not x.requires_grad:
+ return k2.swoosh_r_forward(x)
+ else:
+ return k2.swoosh_r(x)
+ # return SwooshRFunction.apply(x)
+
+
+class SwooshROnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-R activation."""
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
+
+
+# simple version of SwooshL that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshLForward(x: Tensor):
+ x_offset = x - 4.0
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
+ return log_sum - 0.08 * x - 0.035
+
+
+# simple version of SwooshR that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshRForward(x: Tensor):
+ x_offset = x - 1.0
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
+ return log_sum - 0.08 * x - 0.313261687
+
+
+class ActivationDropoutAndLinearFunction(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor],
+ activation: str,
+ dropout_p: float,
+ dropout_shared_dim: Optional[int],
+ ):
+ if dropout_p != 0.0:
+ dropout_shape = list(x.shape)
+ if dropout_shared_dim is not None:
+ dropout_shape[dropout_shared_dim] = 1
+ # else it won't be very memory efficient.
+ dropout_mask = (1.0 / (1.0 - dropout_p)) * (
+ torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
+ )
+ else:
+ dropout_mask = None
+
+ ctx.save_for_backward(x, weight, bias, dropout_mask)
+
+ ctx.activation = activation
+
+ forward_activation_dict = {
+ "SwooshL": k2.swoosh_l_forward,
+ "SwooshR": k2.swoosh_r_forward,
+ }
+ # it will raise a KeyError if this fails. This will be an error. We let it
+ # propagate to the user.
+ activation_func = forward_activation_dict[activation]
+ x = activation_func(x)
+ if dropout_mask is not None:
+ x = x * dropout_mask
+ x = torch.nn.functional.linear(x, weight, bias)
+ return x
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, ans_grad: Tensor):
+ saved = ctx.saved_tensors
+ (x, weight, bias, dropout_mask) = saved
+
+ forward_and_deriv_activation_dict = {
+ "SwooshL": k2.swoosh_l_forward_and_deriv,
+ "SwooshR": k2.swoosh_r_forward_and_deriv,
+ }
+ # the following line raises a KeyError if the activation is unrecognized.
+ # This will be an error. We let it propagate to the user.
+ func = forward_and_deriv_activation_dict[ctx.activation]
+
+ y, func_deriv = func(x)
+ if dropout_mask is not None:
+ y = y * dropout_mask
+ # now compute derivative of y w.r.t. weight and bias..
+ # y: (..., in_channels), ans_grad: (..., out_channels),
+ (out_channels, in_channels) = weight.shape
+
+ in_channels = y.shape[-1]
+ g = ans_grad.reshape(-1, out_channels)
+ weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
+ y_deriv = torch.matmul(ans_grad, weight)
+ bias_deriv = None if bias is None else g.sum(dim=0)
+ x_deriv = y_deriv * func_deriv
+ if dropout_mask is not None:
+ # order versus func_deriv does not matter
+ x_deriv = x_deriv * dropout_mask
+
+ return x_deriv, weight_deriv, bias_deriv, None, None, None
+
+
+class ActivationDropoutAndLinear(torch.nn.Module):
+ """
+ This merges an activation function followed by dropout and then a nn.Linear module;
+ it does so in a memory efficient way so that it only stores the input to the whole
+ module. If activation == SwooshL and dropout_shared_dim != None, this will be
+ equivalent to:
+ nn.Sequential(SwooshL(),
+ Dropout3(dropout_p, shared_dim=dropout_shared_dim),
+ ScaledLinear(in_channels, out_channels, bias=bias,
+ initial_scale=initial_scale))
+ If dropout_shared_dim is None, the dropout would be equivalent to
+ Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout
+ mask is smaller.
+
+ Args:
+ in_channels: number of input channels, e.g. 256
+ out_channels: number of output channels, e.g. 256
+ bias: if true, have a bias
+ activation: the activation function; 'SwooshL' and 'SwooshR' are supported.
+ dropout_p: the dropout probability or schedule (happens after nonlinearity).
+ dropout_shared_dim: the dimension, if any, across which the dropout mask is
+ shared (e.g. the time dimension). If None, this may be less memory
+ efficient if there are modules before this one that cache the input
+ for their backprop (e.g. Balancer or Whiten).
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bias: bool = True,
+ activation: str = "SwooshL",
+ dropout_p: FloatLike = 0.0,
+ dropout_shared_dim: Optional[int] = -1,
+ initial_scale: float = 1.0,
+ ):
+ super().__init__()
+ # create a temporary module of nn.Linear that we'll steal the
+ # weights and bias from
+ l = ScaledLinear(
+ in_channels, out_channels, bias=bias, initial_scale=initial_scale
+ )
+
+ self.weight = l.weight
+ # register_parameter properly handles making it a parameter when l.bias
+ # is None. I think there is some reason for doing it this way rather
+ # than just setting it to None but I don't know what it is, maybe
+ # something to do with exporting the module..
+ self.register_parameter("bias", l.bias)
+
+ self.activation = activation
+ self.dropout_p = dropout_p
+ self.dropout_shared_dim = dropout_shared_dim
+
+ def forward(self, x: Tensor):
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ if self.activation == "SwooshL":
+ x = SwooshLForward(x)
+ elif self.activation == "SwooshR":
+ x = SwooshRForward(x)
+ else:
+ assert False, self.activation
+ return torch.nn.functional.linear(x, self.weight, self.bias)
+
+ return ActivationDropoutAndLinearFunction.apply(
+ x,
+ self.weight,
+ self.bias,
+ self.activation,
+ float(self.dropout_p),
+ self.dropout_shared_dim,
+ )
+
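+# A rough usage sketch (for illustration only; the sizes below are arbitrary):
+# the fused module should behave like the SwooshL -> Dropout3 -> ScaledLinear
+# stack described in its docstring, while caching only its input for backprop.
+#
+#   m = ActivationDropoutAndLinear(256, 512, activation="SwooshL",
+#                                  dropout_p=0.1, dropout_shared_dim=-1)
+#   y = m(torch.randn(10, 8, 256))  # y has shape (10, 8, 512)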
+
+class ActivationDropoutAndLinear_lora(torch.nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bias: bool = True,
+ activation: str = "SwooshL",
+ dropout_p: FloatLike = 0.0,
+ dropout_shared_dim: Optional[int] = -1,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.0,
+ initial_scale: float = 1.0,
+ ):
+ super().__init__()
+ self.l = ScaledLinear_lora(
+ in_features=in_channels,
+ out_features=out_channels,
+ r=r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ initial_scale=initial_scale,
+ bias=bias,
+ )
+ self.weight = self.l.weight
+ self.register_parameter("bias", self.l.bias)
+
+ if activation == "SwooshL":
+ self.activation = SwooshL()
+ elif activation == "SwooshR":
+ self.activation = SwooshR()
+ else:
+ assert False, activation
+ self.dropout = Dropout3(dropout_p, dropout_shared_dim)
+
+ def forward(self, x: Tensor):
+ return self.l(self.dropout(self.activation(x)))
+
+
+def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
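+ """Truncate or zero-pad the last dimension of x so that it has num_channels channels."""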
+ if num_channels <= x.shape[-1]:
+ return x[..., :num_channels]
+ else:
+ shape = list(x.shape)
+ shape[-1] = num_channels - shape[-1]
+ zeros = torch.zeros(shape, dtype=x.dtype, device=x.device)
+ return torch.cat((x, zeros), dim=-1)
+
+
+def _test_whiten():
+ for proportion in [0.1, 0.5, 10.0]:
+ logging.info(f"_test_whiten(): proportion = {proportion}")
+ x = torch.randn(100, 128)
+ direction = torch.randn(128)
+ coeffs = torch.randn(100, 1)
+ x += proportion * direction * coeffs
+
+ x.requires_grad = True
+
+ m = Whiten(
+ 1,  # num_groups
+ 5.0,  # whitening_limit
+ prob=1.0,
+ grad_scale=0.1,
+ )
+
+ for _ in range(4):
+ y = m(x)
+
+ y_grad = torch.randn_like(x)
+ y.backward(gradient=y_grad)
+
+ if proportion < 0.2:
+ assert torch.allclose(x.grad, y_grad)
+ elif proportion > 1.0:
+ assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_balancer_sign():
+ probs = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
+ x = x.detach()
+ x.requires_grad = True
+ m = Balancer(
+ probs.numel(),
+ channel_dim=0,
+ min_positive=0.05,
+ max_positive=0.95,
+ min_abs=0.0,
+ prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_balancer_sign: x = ", x)
+ print("_test_balancer_sign: y grad = ", y_grad)
+ print("_test_balancer_sign: x grad = ", x.grad)
+
+
+def _test_balancer_magnitude():
+ magnitudes = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
+ x = x.detach()
+ x.requires_grad = True
+ m = Balancer(
+ magnitudes.numel(),
+ channel_dim=0,
+ min_positive=0.0,
+ max_positive=1.0,
+ min_abs=0.2,
+ max_abs=0.7,
+ prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_balancer_magnitude: x = ", x)
+ print("_test_balancer_magnitude: y grad = ", y_grad)
+ print("_test_balancer_magnitude: x grad = ", x.grad)
+
+
+def _test_double_swish_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = DoubleSwish()
+
+ tol = (1.2 - (-0.043637)) / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_swooshl_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = SwooshL()
+
+ tol = 1.0 / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_swooshr_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = SwooshR()
+
+ tol = 1.0 / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_softmax():
+ a = torch.randn(2, 10, dtype=torch.float64)
+ b = a.clone()
+ a.requires_grad = True
+ b.requires_grad = True
+ a.softmax(dim=1)[:, 0].sum().backward()
+ print("a grad = ", a.grad)
+ softmax(b, dim=1)[:, 0].sum().backward()
+ print("b grad = ", b.grad)
+ assert torch.allclose(a.grad, b.grad)
+
+
+def _test_piecewise_linear():
+ p = PiecewiseLinear((0, 10.0))
+ for x in [-100, 0, 100]:
+ assert p(x) == 10.0
+ p = PiecewiseLinear((0, 10.0), (1, 0.0))
+ for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
+ print("x, y = ", x, y)
+ assert p(x) == y, (x, p(x), y)
+
+ q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
+ x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
+ pq = p.max(q)
+ for x in x_vals:
+ y1 = max(p(x), q(x))
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+ pq = p.min(q)
+ for x in x_vals:
+ y1 = min(p(x), q(x))
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+ pq = p + q
+ for x in x_vals:
+ y1 = p(x) + q(x)
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+
+
+def _test_activation_dropout_and_linear():
+ in_channels = 20
+ out_channels = 30
+
+ for bias in [True, False]:
+ # actually we don't test for dropout_p != 0.0 because forward functions will give
+ # different answers. This is because we are using the k2 implementation of
+ # swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
+ # internally, messing up the random state.
+ for dropout_p in [0.0]:
+ for activation in ["SwooshL", "SwooshR"]:
+ m1 = nn.Sequential(
+ SwooshL() if activation == "SwooshL" else SwooshR(),
+ Dropout3(p=dropout_p, shared_dim=-1),
+ ScaledLinear(
+ in_channels, out_channels, bias=bias, initial_scale=0.5
+ ),
+ )
+ m2 = ActivationDropoutAndLinear(
+ in_channels,
+ out_channels,
+ bias=bias,
+ initial_scale=0.5,
+ activation=activation,
+ dropout_p=dropout_p,
+ )
+ with torch.no_grad():
+ m2.weight[:] = m1[2].weight
+ if bias:
+ m2.bias[:] = m1[2].bias
+ # make sure forward gives same result.
+ x1 = torch.randn(10, in_channels)
+ x1.requires_grad = True
+
+ # TEMP.
+ assert torch.allclose(
+ SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
+ )
+
+ x2 = x1.clone().detach()
+ x2.requires_grad = True
+ seed = 10
+ torch.manual_seed(seed)
+ y1 = m1(x1)
+ y_grad = torch.randn_like(y1)
+ y1.backward(gradient=y_grad)
+ torch.manual_seed(seed)
+ y2 = m2(x2)
+ y2.backward(gradient=y_grad)
+
+ print(
+ f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
+ )
+ print("y1 = ", y1)
+ print("y2 = ", y2)
+ assert torch.allclose(y1, y2, atol=0.02)
+ assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
+ if bias:
+ assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
+ print("x1.grad = ", x1.grad)
+ print("x2.grad = ", x2.grad)
+
+ def isclose(a, b):
+ # return true if cosine similarity is > 0.9.
+ return (a * b).sum() > 0.9 * (
+ (a**2).sum() * (b**2).sum()
+ ).sqrt()
+
+ # the SwooshL() implementation has a noisy gradient due to 1-byte
+ # storage of it.
+ assert isclose(x1.grad, x2.grad)
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.INFO)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_piecewise_linear()
+ _test_softmax()
+ _test_whiten()
+ _test_balancer_sign()
+ _test_balancer_magnitude()
+ _test_double_swish_deriv()
+ _test_swooshr_deriv()
+ _test_swooshl_deriv()
+ _test_activation_dropout_and_linear()
diff --git a/egs/librispeech/ASR/zipformer_lora/scaling_converter.py b/egs/librispeech/ASR/zipformer_lora/scaling_converter.py
new file mode 120000
index 000000000..bc7c7b5e3
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/scaling_converter.py
@@ -0,0 +1 @@
+../zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/subsampling.py b/egs/librispeech/ASR/zipformer_lora/subsampling.py
new file mode 120000
index 000000000..d178adc2e
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/subsampling.py
@@ -0,0 +1 @@
+../zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py
new file mode 100755
index 000000000..3ccf7d2f1
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/train.py
@@ -0,0 +1,1398 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer_lora/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer_lora/exp \
+ --full-libri 1 \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer_lora/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer_lora/exp \
+ --causal 1 \
+ --full-libri 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
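+ # For example (illustrative numbers): with batch_idx_train=1000, max_duration=1000,
+ # world_size=4 and ref_duration=600, this returns 1000 * (1000 * 4) / 600 ~= 6667.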
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,3,4,3,2",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--downsampling-factor",
+ type=str,
+ default="1,2,4,8,4,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dim",
+ type=str,
+ default="512,768,1024,1536,1024,768",
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--num-heads",
+ type=str,
+ default="4,4,4,8,4,4",
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--encoder-dim",
+ type=str,
+ default="192,256,384,512,384,256",
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=str,
+ default="32",
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=str,
+ default="12",
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=str,
+ default="4",
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default="48",
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dim",
+ type=str,
+ default="192,192,256,256,256,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernel",
+ type=str,
+ default="31,31,15,15,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--causal",
+ type=str2bool,
+ default=False,
+ help="If True, use causal version of model.",
+ )
+
+ parser.add_argument(
+ "--chunk-size",
+ type=str,
+ default="16,32,64,-1",
+ help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+ " Must be just -1 if --causal=False",
+ )
+
+ parser.add_argument(
+ "--left-context-frames",
+ type=str,
+ default="64,128,256,-1",
+ help="Maximum left-contexts for causal training, measured in frames which will "
+ "be converted to a number of chunks. If splitting into chunks, "
+ "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+ )
+
+ parser.add_argument(
+ "--use-transducer",
+ type=str2bool,
+ default=True,
+ help="If True, use Transducer head.",
+ )
+
+ parser.add_argument(
+ "--use-ctc",
+ type=str2bool,
+ default=False,
+ help="If True, use CTC head.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.045, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not changing this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=600,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network)" "part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
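+ # For example, T = 100 input frames become (100 - 7) // 2 = 46 frames here, and
+ # roughly 23 frames after the final output downsampling by 2.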
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.exp_dir/epoch-{params.start_epoch - 1}.pt`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # ramp the scale on the simple loss down from 1.0 at the start
+ # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
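+ # For example (illustrative numbers), with warm_step=2000 and simple_loss_scale=0.5:
+ # at batch 0 the scales are (simple=1.0, pruned=0.1), at batch 1000 (0.75, 0.55),
+ # and from batch 2000 onwards (0.5, 1.0).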
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ librispeech = LibriSpeechAsrDataModule(args)
+
+ if params.full_libri:
+ train_cuts = librispeech.train_all_shuf_cuts()
+
+ # previously we used the following code to load all training cuts,
+ # strictly speaking, shuffled training cuts should be used instead,
+ # but we leave the code here to demonstrate that there is an option
+ # like this to combine multiple cutsets
+
+ # train_cuts = librispeech.train_clean_100_cuts()
+ # train_cuts += librispeech.train_clean_360_cuts()
+ # train_cuts += librispeech.train_other_500_cuts()
+ else:
+ train_cuts = librispeech.train_clean_100_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
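+ # e.g. a cut with num_frames = 1000 (10 seconds at 100 frames/sec) gives
+ # T = ((1000 - 7) // 2 + 1) // 2 = 248 frames after subsampling.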
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = librispeech.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = librispeech.dev_clean_cuts()
+ valid_cuts += librispeech.dev_other_cuts()
+ valid_dl = librispeech.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ sp=sp,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriSpeechAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py
new file mode 100644
index 000000000..43865609a
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py
@@ -0,0 +1,2522 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import math
+import random
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+from encoder_interface import EncoderInterface
+from scaling import (
+ Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+ ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+ ActivationDropoutAndLinear,
+ ActivationDropoutAndLinear_lora,
+ Balancer,
+ BiasNorm,
+ ChunkCausalDepthwiseConv1d,
+ Dropout2,
+ FloatLike,
+ ScaledLinear_lora,
+ ScheduledFloat,
+ Whiten,
+ convert_num_channels,
+ limit_param_value,
+ penalize_abs_values_gt,
+ softmax,
+)
+from torch import Tensor, nn
+
+
+class Zipformer2(EncoderInterface):
+ """
+ Args:
+
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
+ as downsampling_factor if they are single ints or one-element tuples. The length of
+ downsampling_factor defines the number of stacks.
+
+ output_downsampling_factor (int): how much to downsample at the output. Note:
+ we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
+ You should probably leave this at 2.
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
+ Note: this is in addition to the downsampling factor of 2 that is applied in
+ the frontend (self.encoder_embed).
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
+ encoder stack.
+ num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
+ encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
+ the encoder stacks for purposes of per-frame dropout (recommend 256 for
+ now).
+ query_head_dim (int or Tuple[int]): dimension of query and key per attention
+ head; per stack, if a tuple.
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
+ attention head
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
+ Must be at least 4.
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
+ cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module
+
+ pos_dim (int): the dimension of each positional-encoding vector prior to projection,
+ e.g. 128.
+
+ dropout (float): dropout rate
+ warmup_batches (float): number of batches to warm up over; this controls
+ dropout of encoder layers.
+ causal (bool): if True, support chunkwise causal convolution. This should
+ not hurt WER as no modeling power is lost, but the convolution modules will be
+ slightly slower and use more memory. Enables use of the chunk_size and
+ left_context_chunks options in forward(), which simulates streaming
+ decoding.
+ chunk_size: (list of int): only set this to other than [-1] if causal;
+ the chunk size will be randomly chosen from this list. -1 means no chunking.
+ left_context_frames: (list of int): determines the number of left-
+ context chunks for causal training; will be rounded to a number of
+ chunks. Must not be less than cnn_module_kernel (after factoring in
+ rounding and downsampling); an error will be thrown if this is violated.
+ """
+
+ def __init__(
+ self,
+ output_downsampling_factor: int = 2,
+ downsampling_factor: Tuple[int] = (2, 4),
+ encoder_dim: Union[int, Tuple[int]] = 384,
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
+ encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
+ query_head_dim: Union[int, Tuple[int]] = 24,
+ pos_head_dim: Union[int, Tuple[int]] = 4,
+ value_head_dim: Union[int, Tuple[int]] = 12,
+ num_heads: Union[int, Tuple[int]] = 8,
+ feedforward_dim: Union[int, Tuple[int]] = 1536,
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
+ pos_dim: int = 192,
+ dropout: FloatLike = None, # see code below for default
+ warmup_batches: float = 4000.0,
+ causal: bool = False,
+ chunk_size: Tuple[int] = [-1],
+ left_context_frames: Tuple[int] = [-1],
+ use_lora: bool = True,
+ lora_r: int = 0,
+ ) -> None:
+ super(Zipformer2, self).__init__()
+
+ if dropout is None:
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+
+ def _to_tuple(x):
+ """Converts a single int or a 1-tuple of an int to a tuple with the same length
+ as downsampling_factor"""
+ if isinstance(x, int):
+ x = (x,)
+ if len(x) == 1:
+ x = x * len(downsampling_factor)
+ else:
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+ return x
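+ # e.g. if downsampling_factor has length 6, _to_tuple(4) -> (4, 4, 4, 4, 4, 4),
+ # while _to_tuple((2, 2, 3, 4, 3, 2)) is returned unchanged.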
+
+ self.output_downsampling_factor = output_downsampling_factor # int
+ self.downsampling_factor = downsampling_factor # tuple
+ self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
+ self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
+ encoder_unmasked_dim
+ ) # tuple
+ num_encoder_layers = _to_tuple(num_encoder_layers)
+ self.num_encoder_layers = num_encoder_layers
+ self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
+ self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
+ pos_head_dim = _to_tuple(pos_head_dim)
+ self.num_heads = num_heads = _to_tuple(num_heads)
+ feedforward_dim = _to_tuple(feedforward_dim)
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
+
+ self.causal = causal
+ self.chunk_size = chunk_size
+ self.left_context_frames = left_context_frames
+
+ self.lora_r = lora_r if use_lora else 0
+
+ for u, d in zip(encoder_unmasked_dim, encoder_dim):
+ assert u <= d
+
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+ encoders = []
+
+ num_encoders = len(downsampling_factor)
+ for i in range(num_encoders):
+ encoder_layer = Zipformer2EncoderLayer(
+ embed_dim=encoder_dim[i],
+ pos_dim=pos_dim,
+ num_heads=num_heads[i],
+ query_head_dim=query_head_dim[i],
+ pos_head_dim=pos_head_dim[i],
+ value_head_dim=value_head_dim[i],
+ feedforward_dim=feedforward_dim[i],
+ dropout=dropout,
+ cnn_module_kernel=cnn_module_kernel[i],
+ causal=causal,
+ lora_r=self.lora_r,
+ )
+
+ # For the segment of the warmup period, we let the Conv2dSubsampling
+ # layer learn something. Then we start to warm up the other encoders.
+ encoder = Zipformer2Encoder(
+ encoder_layer,
+ num_encoder_layers[i],
+ pos_dim=pos_dim,
+ dropout=dropout,
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+ )
+
+ if downsampling_factor[i] != 1:
+ encoder = DownsampledZipformer2Encoder(
+ encoder,
+ dim=encoder_dim[i],
+ downsample=downsampling_factor[i],
+ dropout=dropout,
+ )
+
+ encoders.append(encoder)
+
+ self.encoders = nn.ModuleList(encoders)
+
+ self.downsample_output = SimpleDownsample(
+ max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
+ )
+
+ def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
+ """
+ In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
+ randomized feature masks, one per encoder.
+ On e.g. 15% of frames, these masks will zero out all encoder dims larger than
+ some supplied number, e.g. >256, so in effect on those frames we are using
+ a smaller encoder dim.
+
+ We generate the random masks at this level because we want the 2 masks to 'agree'
+ all the way up the encoder stack. This will mean that the 1st mask will have
+ mask values repeated self.zipformer_subsampling_factor times.
+
+ Args:
+ x: the embeddings (needed for the shape and dtype and device), of shape
+ (1, batch_size, encoder_dims0)
+ """
+ num_encoders = len(self.encoder_dim)
+ if not self.training:
+ return [1.0] * num_encoders
+
+ (num_frames0, batch_size, _encoder_dims0) = x.shape
+
+ assert self.encoder_dim[0] == _encoder_dims0, (
+ self.encoder_dim[0],
+ _encoder_dims0,
+ )
+
+ feature_mask_dropout_prob = 0.125
+
+ # mask1 shape: (1, batch_size, 1)
+ mask1 = (
+ torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
+ ).to(x.dtype)
+
+ # mask2 has additional sequences masked, about twice the number.
+ mask2 = torch.logical_and(
+ mask1,
+ (
+ torch.rand(1, batch_size, 1, device=x.device)
+ > feature_mask_dropout_prob
+ ).to(x.dtype),
+ )
+
+ # dim: (1, batch_size, 2)
+ mask = torch.cat((mask1, mask2), dim=-1)
+
+ feature_masks = []
+ for i in range(num_encoders):
+ channels = self.encoder_dim[i]
+ feature_mask = torch.ones(
+ 1, batch_size, channels, dtype=x.dtype, device=x.device
+ )
+ u1 = self.encoder_unmasked_dim[i]
+ u2 = u1 + (channels - u1) // 2
+
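+ # e.g. with channels = 384 and u1 = 256 we get u2 = 320: dims [0, 256) are never
+ # masked, dims [256, 320) follow mask1 and dims [320, 384) follow mask2.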
+ feature_mask[:, :, u1:u2] *= mask[..., 0:1]
+ feature_mask[:, :, u2:] *= mask[..., 1:2]
+
+ feature_masks.append(feature_mask)
+
+ return feature_masks
+
+ def get_chunk_info(self) -> Tuple[int, int]:
+ """
+ Returns chunk_size and left_context_chunks.
+ """
+ if not self.causal:
+ return -1, -1
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ assert len(self.chunk_size) == 1, self.chunk_size
+ chunk_size = self.chunk_size[0]
+ else:
+ chunk_size = random.choice(self.chunk_size)
+
+ if chunk_size == -1:
+ left_context_chunks = -1
+ else:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ assert len(self.left_context_frames) == 1, self.left_context_frames
+ left_context_frames = self.left_context_frames[0]
+ else:
+ left_context_frames = random.choice(self.left_context_frames)
+ # Note: in Python, -1 // n == -1 for n > 0
+ left_context_chunks = left_context_frames // chunk_size
+ if left_context_chunks == 0:
+ left_context_chunks = 1
+
+ return chunk_size, left_context_chunks
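+ # Worked example (assumed options, not from a specific recipe): with
+ # chunk_size = (16, 32) and left_context_frames = (64, 128), one training
+ # call might pick chunk_size = 16 and left_context_frames = 128, giving
+ # left_context_chunks = 128 // 16 = 8; picking left_context_frames = -1
+ # instead gives left_context_chunks = -1 // 16 = -1, i.e. unlimited left
+ # context.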
+
+ def forward(
+ self,
+ x: Tensor,
+ x_lens: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+ x_lens:
+ A tensor of shape (batch_size,) containing the number of frames in
+ `x` before padding.
+ src_key_padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+ Return a tuple containing 2 tensors:
+ - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+ - lengths, a tensor of shape (batch_size,) containing the number
+ of frames in `embeddings` before padding.
+ """
+ outputs = []
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ feature_masks = [1.0] * len(self.encoder_dim)
+ else:
+ feature_masks = self.get_feature_masks(x)
+
+ chunk_size, left_context_chunks = self.get_chunk_info()
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ # Exporting a model for simulated streaming decoding is not supported
+ attn_mask = None
+ else:
+ attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+
+ for i, module in enumerate(self.encoders):
+ ds = self.downsampling_factor[i]
+ x = convert_num_channels(x, self.encoder_dim[i])
+
+ x = module(
+ x,
+ chunk_size=chunk_size,
+ feature_mask=feature_masks[i],
+ src_key_padding_mask=(
+ None
+ if src_key_padding_mask is None
+ else src_key_padding_mask[..., ::ds]
+ ),
+ attn_mask=attn_mask,
+ )
+ outputs.append(x)
+
+ # if the last output has the largest dimension, x will be unchanged,
+ # it will be the same as outputs[-1]. Otherwise it will be concatenated
+ # from different pieces of 'outputs', taking each dimension from the
+ # most recent output that has it present.
+ x = self._get_full_dim_output(outputs)
+ x = self.downsample_output(x)
+ # class SimpleDownsample has this rounding behavior.
+ assert self.output_downsampling_factor == 2, self.output_downsampling_factor
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ lengths = (x_lens + 1) // 2
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ lengths = (x_lens + 1) // 2
+
+ return x, lengths
+
+ def _get_attn_mask(
+ self, x: Tensor, chunk_size: int, left_context_chunks: int
+ ) -> Optional[Tensor]:
+ """
+ Return None if chunk_size == -1, else return attention mask of shape
+ (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True
+ means a masked position.
+ Args:
+ x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
+ chunk_size: chunk size; it must be divisible by every downsampling factor.
+ """
+ if chunk_size <= 0:
+ return None
+ assert all(chunk_size % d == 0 for d in self.downsampling_factor)
+ if left_context_chunks >= 0:
+ num_encoders = len(self.encoder_dim)
+ assert all(
+ chunk_size * left_context_chunks
+ >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
+ for i in range(num_encoders)
+ )
+ else:
+ left_context_chunks = 1000000
+
+ seq_len = x.shape[0]
+
+ # t is frame index, shape (seq_len,)
+ t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
+ # c is chunk index for each frame, shape (seq_len,)
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ c = t // chunk_size
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ c = t // chunk_size
+ src_c = c
+ tgt_c = c.unsqueeze(-1)
+
+ attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
+ if __name__ == "__main__":
+ logging.info(f"attn_mask = {attn_mask}")
+ return attn_mask
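+ # Small worked example (not executed): with seq_len = 6, chunk_size = 2 and
+ # left_context_chunks = 1, the chunk index per frame is c = [0, 0, 1, 1, 2, 2].
+ # A target frame in chunk 2 (rows 4 and 5) masks source columns 0 and 1
+ # (src_c < tgt_c - left_context_chunks) and leaves columns 2..5 visible;
+ # frames never attend to later chunks because src_c > tgt_c is masked.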
+
+ def _get_full_dim_output(self, outputs: List[Tensor]):
+ num_encoders = len(self.encoder_dim)
+ assert len(outputs) == num_encoders
+ output_dim = max(self.encoder_dim)
+ output_pieces = [outputs[-1]]
+ cur_dim = self.encoder_dim[-1]
+ for i in range(num_encoders - 2, -1, -1):
+ d = self.encoder_dim[i]
+ if d > cur_dim:
+ this_output = outputs[i]
+ output_pieces.append(this_output[..., cur_dim:d])
+ cur_dim = d
+ assert cur_dim == output_dim
+ return torch.cat(output_pieces, dim=-1)
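+ # Sketch with an assumed encoder_dim = (192, 256, 192): output_dim = 256;
+ # output_pieces starts with outputs[2] (192 dims), then outputs[1][..., 192:256]
+ # is appended because encoder 1 is the most recent output that has those 64
+ # channels, giving a (seq_len, batch_size, 256) result.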
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ x_lens: Tensor,
+ states: List[Tensor],
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+ x_lens:
+ A tensor of shape (batch_size,) containing the number of frames in
+ `x` before padding.
+ states: list of cached tensors of all encoder layers. For layer-i,
+ states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+ cached_conv1, cached_conv2).
+ src_key_padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+ Return a tuple containing 3 elements:
+ - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+ - lengths, a tensor of shape (batch_size,) containing the number
+ of frames in `embeddings` before padding.
+ - updated states
+ """
+ outputs = []
+ new_states = []
+ layer_offset = 0
+
+ for i, module in enumerate(self.encoders):
+ num_layers = module.num_layers
+ ds = self.downsampling_factor[i]
+ x = convert_num_channels(x, self.encoder_dim[i])
+
+ x, new_layer_states = module.streaming_forward(
+ x,
+ states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
+ left_context_len=self.left_context_frames[0] // ds,
+ src_key_padding_mask=src_key_padding_mask[..., ::ds],
+ )
+ layer_offset += num_layers
+ outputs.append(x)
+ new_states += new_layer_states
+
+ # if the last output has the largest dimension, x will be unchanged,
+ # it will be the same as outputs[-1]. Otherwise it will be concatenated
+ # from different pieces of 'outputs', taking each dimension from the
+ # most recent output that has it present.
+ x = self._get_full_dim_output(outputs)
+ x = self.downsample_output(x)
+ # class SimpleDownsample has this rounding behavior.
+ assert self.output_downsampling_factor == 2
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ lengths = (x_lens + 1) // 2
+ else:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ lengths = (x_lens + 1) // 2
+
+ return x, lengths, new_states
+
+ @torch.jit.export
+ def get_init_states(
+ self,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+ ) -> List[Tensor]:
+ """Get initial states.
+
+ A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ """
+ states = []
+ for i, module in enumerate(self.encoders):
+ num_layers = module.num_layers
+ embed_dim = self.encoder_dim[i]
+ ds = self.downsampling_factor[i]
+ num_heads = self.num_heads[i]
+ key_dim = self.query_head_dim[i] * num_heads
+ value_dim = self.value_head_dim[i] * num_heads
+ downsample_left = self.left_context_frames[0] // ds
+ nonlin_attn_head_dim = 3 * embed_dim // 4
+ conv_left_pad = self.cnn_module_kernel[i] // 2
+ for layer in range(num_layers):
+ cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
+ device
+ )
+ cached_nonlin_attn = torch.zeros(
+ 1, batch_size, downsample_left, nonlin_attn_head_dim
+ ).to(device)
+ cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
+ device
+ )
+ cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
+ device
+ )
+ cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+ device
+ )
+ cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+ device
+ )
+ states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ return states
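+ # Layout sketch (assumed config): with num_encoder_layers = (2, 2, 3, 4, 3, 2),
+ # i.e. 16 layers in total, get_init_states() returns 16 * 6 = 96 tensors, and
+ # states[i * 6 : (i + 1) * 6] always holds (cached_key, cached_nonlin_attn,
+ # cached_val1, cached_val2, cached_conv1, cached_conv2) for layer i, which is
+ # exactly the layout streaming_forward() indexes into.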
+
+
+def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
+ return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
+
+
+def _balancer_schedule(min_prob: float):
+ return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))
+
+
+class Zipformer2EncoderLayer(nn.Module):
+ """
+ Args:
+ embed_dim: the number of expected features in the input (required).
+ num_heads: the number of heads in the multi-head attention models (required).
+ feedforward_dim: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ cnn_module_kernel (int): Kernel size of convolution module.
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, pos_dim=192,
+ ... num_heads=8, query_head_dim=32, pos_head_dim=4, value_head_dim=12,
+ ... feedforward_dim=2048)
+ >>> src = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(1, 19, 192)
+ >>> out = encoder_layer(src, pos_emb)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ value_head_dim: int,
+ feedforward_dim: int,
+ dropout: FloatLike = 0.1,
+ cnn_module_kernel: int = 31,
+ causal: bool = False,
+ attention_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ conv_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ const_attention_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.25), (4000.0, 0.025), default=0
+ ),
+ ff2_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ ff3_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ bypass_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.5), (4000.0, 0.02), default=0
+ ),
+ lora_r: int = 0,
+ lora_alpha: int = 4,
+ lora_dropout: float = 0.0,
+ ) -> None:
+ super(Zipformer2EncoderLayer, self).__init__()
+ self.embed_dim = embed_dim
+
+ # self.bypass implements layer skipping as well as bypass; see its default values.
+ self.bypass = BypassModule(
+ embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
+ )
+ # bypass_mid is bypass used in the middle of the layer.
+ self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
+
+ # skip probability for dynamic modules (meaning: anything but feedforward).
+ self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
+ # an additional skip probability that applies to ConvModule to stop it from
+ # contributing too much early on.
+ self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+
+ # ff2_skip_rate is to prevent the ff2 module from having output that's too big
+ # compared to its residual.
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
+ self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
+
+ self.const_attention_rate = copy.deepcopy(const_attention_rate)
+
+ self.self_attn_weights = RelPositionMultiheadAttentionWeights(
+ embed_dim,
+ pos_dim=pos_dim,
+ num_heads=num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ dropout=0.0,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.self_attn1 = SelfAttention(
+ embed_dim,
+ num_heads,
+ value_head_dim,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.self_attn2 = SelfAttention(
+ embed_dim,
+ num_heads,
+ value_head_dim,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.feed_forward1 = FeedforwardModule(
+ embed_dim,
+ (feedforward_dim * 3) // 4,
+ dropout,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.feed_forward2 = FeedforwardModule(
+ embed_dim,
+ feedforward_dim,
+ dropout,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.feed_forward3 = FeedforwardModule(
+ embed_dim,
+ (feedforward_dim * 5) // 4,
+ dropout,
+ lora_r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+
+ self.nonlin_attention = NonlinAttention(
+ embed_dim, hidden_channels=3 * embed_dim // 4
+ )
+
+ self.conv_module1 = ConvolutionModule(
+ embed_dim, cnn_module_kernel, causal=causal
+ )
+
+ self.conv_module2 = ConvolutionModule(
+ embed_dim, cnn_module_kernel, causal=causal
+ )
+
+ # TODO: remove it
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+
+ self.norm = BiasNorm(embed_dim)
+
+ self.balancer1 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.2,
+ max_abs=4.0,
+ )
+
+ # balancer for output of NonlinAttentionModule
+ self.balancer_na = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+ prob=0.05, # out of concern for memory usage
+ )
+
+ # balancer for output of feedforward2, prevent it from staying too
+ # small. give this a very small probability, even at the start of
+ # training, it's to fix a rare problem and it's OK to fix it slowly.
+ self.balancer_ff2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
+ max_abs=2.0,
+ prob=0.05,
+ )
+
+ self.balancer_ff3 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
+ max_abs=4.0,
+ prob=0.05,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(4.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.balancer2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.1,
+ max_abs=4.0,
+ )
+
+ def get_sequence_dropout_mask(
+ self, x: Tensor, dropout_rate: float
+ ) -> Optional[Tensor]:
+ if (
+ dropout_rate == 0.0
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return None
+ batch_size = x.shape[1]
+ mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+ return mask
+
+ def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+ """
+ Apply sequence-level dropout to x.
+ x shape: (seq_len, batch_size, embed_dim)
+ """
+ dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+ if dropout_mask is None:
+ return x
+ else:
+ return x * dropout_mask
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ chunk_size: int = -1,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the encoder layer.
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+ chunk_size: the number of frames per chunk (> 0), or -1 for no chunking.
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns:
+ A tensor which has the same shape as src
+ """
+ src_orig = src
+
+ # dropout rate for non-feedforward submodules
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ attention_skip_rate = 0.0
+ else:
+ attention_skip_rate = (
+ float(self.attention_skip_rate) if self.training else 0.0
+ )
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights = self.self_attn_weights(
+ src,
+ pos_emb=pos_emb,
+ attn_mask=attn_mask,
+ key_padding_mask=src_key_padding_mask,
+ )
+
+ src = src + self.feed_forward1(src)
+
+ self_attn_dropout_mask = self.get_sequence_dropout_mask(
+ src, attention_skip_rate
+ )
+
+ selected_attn_weights = attn_weights[0:1]
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif not self.training and random.random() < float(self.const_attention_rate):
+ # Make attention weights constant. The intention is to
+ # encourage these modules to do something similar to an
+ # averaging-over-time operation.
+ # only need the mask, can just use the 1st one and expand later
+ selected_attn_weights = selected_attn_weights[0:1]
+ selected_attn_weights = (selected_attn_weights > 0.0).to(
+ selected_attn_weights.dtype
+ )
+ selected_attn_weights = selected_attn_weights * (
+ 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
+ )
+
+ na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+
+ src = src + (
+ na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
+ )
+
+ self_attn = self.self_attn1(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.conv_module1(
+ src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff2_skip_rate = 0.0
+ else:
+ ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
+ )
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ self_attn = self.self_attn2(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.conv_module2(
+ src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff3_skip_rate = 0.0
+ else:
+ ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
+ )
+
+ src = self.balancer1(src)
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ src = self.balancer2(src)
+ src = self.whiten(src)
+
+ return src
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ cached_key: Tensor,
+ cached_nonlin_attn: Tensor,
+ cached_val1: Tensor,
+ cached_val2: Tensor,
+ cached_conv1: Tensor,
+ cached_conv2: Tensor,
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Pass the input through the encoder layer in streaming forward mode.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
+ (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
+ cached_key: cached attention key tensor of left context,
+ of shape (left_context_len, batch_size, key_dim)
+ cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
+ (num_heads, batch_size, left_context_len, head_dim)
+ cached_val1: cached left context for the first attention module,
+ of shape (left_context_len, batch_size, value_dim)
+ cached_val2: cached left context for the second attention module,
+ of shape (left_context_len, batch_size, value_dim)
+ cached_conv1: cached left context for the first convolution module,
+ of shape (batch_size, channels, left_pad)
+ cached_conv2: cached left context for the second convolution module,
+ of shape (batch_size, channels, left_pad)
+ left_context_len: number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape
+ (batch_size, left_context_len + seq_len); True means masked position.
+ May be None.
+
+ Returns:
+ - x, with the same shape as src
+ - updated cached_key
+ - updated cached_nonlin_attn
+ - updated cached_val1
+ - updated cached_val2
+ - updated cached_conv1
+ - updated cached_conv2
+ """
+ src_orig = src
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights, cached_key = self.self_attn_weights.streaming_forward(
+ src,
+ pos_emb=pos_emb,
+ cached_key=cached_key,
+ left_context_len=left_context_len,
+ key_padding_mask=src_key_padding_mask,
+ )
+
+ src = src + self.feed_forward1(src)
+
+ na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
+ src,
+ attn_weights[0:1],
+ cached_x=cached_nonlin_attn,
+ left_context_len=left_context_len,
+ )
+ src = src + na
+
+ self_attn, cached_val1 = self.self_attn1.streaming_forward(
+ src,
+ attn_weights=attn_weights,
+ cached_val=cached_val1,
+ left_context_len=left_context_len,
+ )
+ src = src + self_attn
+
+ src_conv, cached_conv1 = self.conv_module1.streaming_forward(
+ src,
+ cache=cached_conv1,
+ src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+ )
+ src = src + src_conv
+
+ src = src + self.feed_forward2(src)
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ self_attn, cached_val2 = self.self_attn2.streaming_forward(
+ src,
+ attn_weights=attn_weights,
+ cached_val=cached_val2,
+ left_context_len=left_context_len,
+ )
+ src = src + self_attn
+
+ src_conv, cached_conv2 = self.conv_module2.streaming_forward(
+ src,
+ cache=cached_conv2,
+ src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+ )
+ src = src + src_conv
+
+ src = src + self.feed_forward3(src)
+
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ return (
+ src,
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ )
+
+
+class Zipformer2Encoder(nn.Module):
+ r"""Zipformer2Encoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ pos_dim: the dimension for the relative positional encoding
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, pos_dim=192,
+ ... num_heads=8, query_head_dim=32, pos_head_dim=4, value_head_dim=12,
+ ... feedforward_dim=2048)
+ >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6,
+ ... pos_dim=192, dropout=0.1, warmup_begin=2000.0, warmup_end=4000.0)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = zipformer_encoder(src)
+ """
+
+ def __init__(
+ self,
+ encoder_layer: nn.Module,
+ num_layers: int,
+ pos_dim: int,
+ dropout: float,
+ warmup_begin: float,
+ warmup_end: float,
+ initial_layerdrop_rate: float = 0.5,
+ final_layerdrop_rate: float = 0.05,
+ ) -> None:
+ super().__init__()
+ self.encoder_pos = CompactRelPositionalEncoding(
+ pos_dim, dropout_rate=0.15, length_factor=1.0
+ )
+
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ assert 0 <= warmup_begin <= warmup_end
+
+ delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
+ cur_begin = warmup_begin # interpreted as a training batch index
+ for i in range(num_layers):
+ cur_end = cur_begin + delta
+ self.layers[i].bypass.skip_rate = ScheduledFloat(
+ (cur_begin, initial_layerdrop_rate),
+ (cur_end, final_layerdrop_rate),
+ default=0.0,
+ )
+ cur_begin = cur_end
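+ # Layerdrop schedule sketch (assumed numbers): with num_layers = 4,
+ # warmup_begin = 2000 and warmup_end = 4000, delta = 500, so layer 0's bypass
+ # skip_rate roughly decays from 0.5 to 0.05 over batches 2000-2500, layer 1
+ # over batches 2500-3000, and so on; later layers are therefore skipped more
+ # often for longer at the start of training.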
+
+ def forward(
+ self,
+ src: Tensor,
+ chunk_size: int = -1,
+ feature_mask: Union[Tensor, float] = 1.0,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ chunk_size: the number of frames per chunk (> 0), or -1 for no chunking.
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
+ by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ pos_emb = self.encoder_pos(src)
+ output = src
+
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ output = output * feature_mask
+
+ for i, mod in enumerate(self.layers):
+ output = mod(
+ output,
+ pos_emb,
+ chunk_size=chunk_size,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+ output = output * feature_mask
+
+ return output
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ states: List[Tensor],
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, List[Tensor]]:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+ (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ left_context_len: Number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape
+ (batch_size, left_context_len + seq_len); True means masked position.
+ May be None.
+
+ Returns:
+ - output, a Tensor with the same shape as src.
+ - updated states
+ """
+ pos_emb = self.encoder_pos(src, left_context_len)
+ output = src
+
+ new_states = []
+ for i, mod in enumerate(self.layers):
+ (
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ) = states[i * 6 : (i + 1) * 6]
+ (
+ output,
+ new_cached_key,
+ new_cached_nonlin_attn,
+ new_cached_val1,
+ new_cached_val2,
+ new_cached_conv1,
+ new_cached_conv2,
+ ) = mod.streaming_forward(
+ output,
+ pos_emb,
+ cached_key=cached_key,
+ cached_nonlin_attn=cached_nonlin_attn,
+ cached_val1=cached_val1,
+ cached_val2=cached_val2,
+ cached_conv1=cached_conv1,
+ cached_conv2=cached_conv2,
+ left_context_len=left_context_len,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ new_states += [
+ new_cached_key,
+ new_cached_nonlin_attn,
+ new_cached_val1,
+ new_cached_val2,
+ new_cached_conv1,
+ new_cached_conv2,
+ ]
+
+ return output, new_states
+
+
+class BypassModule(nn.Module):
+ """
+ An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
+ layer-skipping. The bypass is limited during early stages of training to be close to
+ "straight-through", i.e. to not do the bypass operation much initially, in order to
+ force all the modules to learn something.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ skip_rate: FloatLike = 0.0,
+ straight_through_rate: FloatLike = 0.0,
+ scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
+ scale_max: FloatLike = 1.0,
+ ):
+ super().__init__()
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+ self.skip_rate = copy.deepcopy(skip_rate)
+ self.straight_through_rate = copy.deepcopy(straight_through_rate)
+ self.scale_min = copy.deepcopy(scale_min)
+ self.scale_max = copy.deepcopy(scale_max)
+
+ def _get_bypass_scale(self, batch_size: int):
+ # returns bypass-scale of shape (num_channels,),
+ # or (batch_size, num_channels,). This is actually the
+ # scale on the non-residual term, so 0 corresponds to bypassing
+ # this module.
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return self.bypass_scale
+ else:
+ ans = limit_param_value(
+ self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
+ )
+ skip_rate = float(self.skip_rate)
+ if skip_rate != 0.0:
+ mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
+ ans = ans * mask
+ # now ans is of shape (batch_size, num_channels), and is zero for sequences
+ # on which we have randomly chosen to do layer-skipping.
+ straight_through_rate = float(self.straight_through_rate)
+ if straight_through_rate != 0.0:
+ mask = (
+ torch.rand((batch_size, 1), device=ans.device)
+ < straight_through_rate
+ )
+ ans = torch.maximum(ans, mask.to(ans.dtype))
+ return ans
+
+ def forward(self, src_orig: Tensor, src: Tensor):
+ """
+ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
+ Returns: something with the same shape as src and src_orig
+ """
+ bypass_scale = self._get_bypass_scale(src.shape[1])
+ return src_orig + (src - src_orig) * bypass_scale
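+ # Note on the formula above: it is a per-channel interpolation,
+ #   out = src_orig + (src - src_orig) * bypass_scale
+ #       = (1 - bypass_scale) * src_orig + bypass_scale * src,
+ # so a scale of 0 skips this module's contribution entirely and a scale of 1
+ # keeps its full output; a nonzero straight_through_rate randomly forces the
+ # scale to 1 on some sequences.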
+
+
+class DownsampledZipformer2Encoder(nn.Module):
+ r"""
+ DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate
+ after downsampling; its output is then upsampled again and combined
+ with the original input, so that the output has the same shape as the input.
+ """
+
+ def __init__(
+ self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
+ ):
+ super(DownsampledZipformer2Encoder, self).__init__()
+ self.downsample_factor = downsample
+ self.downsample = SimpleDownsample(dim, downsample, dropout)
+ self.num_layers = encoder.num_layers
+ self.encoder = encoder
+ self.upsample = SimpleUpsample(dim, downsample)
+ self.out_combiner = BypassModule(dim, straight_through_rate=0)
+
+ def forward(
+ self,
+ src: Tensor,
+ chunk_size: int = -1,
+ feature_mask: Union[Tensor, float] = 1.0,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Downsample, go through encoder, upsample.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
+ by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ src_orig = src
+ src = self.downsample(src)
+ ds = self.downsample_factor
+ if attn_mask is not None:
+ attn_mask = attn_mask[::ds, ::ds]
+
+ src = self.encoder(
+ src,
+ chunk_size=chunk_size // ds,
+ feature_mask=feature_mask,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+ # remove any extra frames that are not a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src)
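+ # Shape sketch (assumed values): with downsample = 4 and src of shape
+ # (102, B, D), self.downsample gives (26, B, D), the inner encoder keeps that
+ # shape, self.upsample repeats it to (104, B, D), and the final slice
+ # src[: src_orig.shape[0]] trims it back to (102, B, D) before out_combiner.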
+
+ def streaming_forward(
+ self,
+ src: Tensor,
+ states: List[Tensor],
+ left_context_len: int,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, List[Tensor]]:
+ r"""Downsample, go through encoder, upsample, in streaming forward mode.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+ (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ left_context_len: Number of left context frames.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len);
+ True means masked position. May be None.
+
+ Returns:
+ - output, a Tensor with the same shape as src.
+ - updated states
+ """
+ src_orig = src
+ src = self.downsample(src)
+
+ src, new_states = self.encoder.streaming_forward(
+ src,
+ states=states,
+ left_context_len=left_context_len,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+ # remove any extra frames that are not a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src), new_states
+
+
+class SimpleDownsample(torch.nn.Module):
+ """
+ Does downsampling by a learned, softmax-weighted sum over each group of
+ `downsample` consecutive frames.
+ """
+
+ def __init__(self, channels: int, downsample: int, dropout: FloatLike):
+ super(SimpleDownsample, self).__init__()
+
+ self.bias = nn.Parameter(torch.zeros(downsample))
+
+ self.name = None # will be set from training code
+ self.dropout = copy.deepcopy(dropout)
+
+ self.downsample = downsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ src: (seq_len, batch_size, in_channels)
+ Returns a tensor of shape
+ ((seq_len + downsample - 1) // downsample, batch_size, in_channels)
+ """
+ (seq_len, batch_size, in_channels) = src.shape
+ ds = self.downsample
+ d_seq_len = (seq_len + ds - 1) // ds
+
+ # Pad to an exact multiple of self.downsample
+ # right-pad src, repeating the last element.
+ pad = d_seq_len * ds - seq_len
+ src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+ src = torch.cat((src, src_extra), dim=0)
+ assert src.shape[0] == d_seq_len * ds
+
+ src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+
+ weights = self.bias.softmax(dim=0)
+ # weights: (downsample, 1, 1)
+ weights = weights.unsqueeze(-1).unsqueeze(-1)
+
+ # ans: (d_seq_len, batch_size, in_channels)
+ ans = (src * weights).sum(dim=1)
+
+ return ans
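+ # Minimal usage sketch (assumed shapes, not executed here):
+ #   m = SimpleDownsample(channels=256, downsample=2, dropout=0.0)
+ #   y = m(torch.rand(9, 4, 256))   # y has shape (5, 4, 256)
+ # frame 8 is repeated once to pad the length to 10 before the softmax-weighted
+ # sum over each pair of frames.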
+
+
+class SimpleUpsample(torch.nn.Module):
+ """
+ A very simple form of upsampling that just repeats each input frame
+ `upsample` times.
+ """
+
+ def __init__(self, num_channels: int, upsample: int):
+ super(SimpleUpsample, self).__init__()
+ self.upsample = upsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ src: (seq_len, batch_size, num_channels)
+ Returns a tensor of shape
+ ( (seq_len*upsample), batch_size, num_channels)
+ """
+ upsample = self.upsample
+ (seq_len, batch_size, num_channels) = src.shape
+ src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+ src = src.reshape(seq_len * upsample, batch_size, num_channels)
+ return src
+
+
+class CompactRelPositionalEncoding(torch.nn.Module):
+ """
+ Relative positional encoding module. This version is "compact" meaning it is able to encode
+ the important information about the relative position in a relatively small number of dimensions.
+ The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
+ make very little difference to the embedding. Such differences were potentially important
+ when encoding absolute position, but not important when encoding relative position because there
+ is now no need to compare two large offsets with each other.
+
+ Our embedding works by projecting the interval [-infinity,infinity] to a finite interval
+ using the atan() function, before doing the Fourier transform of that fixed interval. On
+ its own, the atan() function would compress the "long tails" too much,
+ making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+ function to compress large offsets to a smaller range before applying atan().
+ Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
+ as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim)
+
+
+ Args:
+ embed_dim: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length: just a heuristic for initialization.
+ length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+ less weight to small differences of offset near the origin.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ dropout_rate: FloatLike,
+ max_len: int = 1000,
+ length_factor: float = 1.0,
+ ) -> None:
+ """Construct a CompactRelPositionalEncoding object."""
+ super(CompactRelPositionalEncoding, self).__init__()
+ self.embed_dim = embed_dim
+ assert embed_dim % 2 == 0
+ self.dropout = Dropout2(dropout_rate)
+ self.pe = None
+ assert length_factor >= 1.0
+ self.length_factor = length_factor
+ self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+ """Reset the positional encodings."""
+ T = x.size(0) + left_context_len
+
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(0) >= T * 2 - 1:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+
+ # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
+ x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+
+ freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+ # `compression_length` is arbitrary/heuristic: if it is larger we have more resolution
+ # for small time offsets but less resolution for large time offsets.
+ compression_length = self.embed_dim**0.5
+ # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity;
+ # but it does so more slowly than x for large absolute values of x.
+ # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
+ # is important.
+ x_compressed = (
+ compression_length
+ * x.sign()
+ * ((x.abs() + compression_length).log() - math.log(compression_length))
+ )
+
+ # if self.length_factor == 1.0, then length_scale is chosen so that the
+ # FFT can exactly separate points close to the origin (T == 0). So this
+ # part of the formulation is not really heuristic.
+ # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+ length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
+
+ # note for machine implementations: if atan is not available, we can use:
+ # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
+ # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
+ x_atan = (x_compressed / length_scale).atan() # results between -pi/2 and pi/2
+
+ cosines = (x_atan * freqs).cos()
+ sines = (x_atan * freqs).sin()
+
+ pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+ pe[:, 0::2] = cosines
+ pe[:, 1::2] = sines
+ pe[:, -1] = 1.0 # for bias.
+
+ self.pe = pe.to(dtype=x.dtype)
+
+ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+ """Create positional encoding.
+
+ Args:
+ x (Tensor): Input tensor (time, batch, `*`).
+ left_context_len: (int): Length of cached left context.
+
+ Returns:
+ positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
+ """
+ self.extend_pe(x, left_context_len)
+ x_size_left = x.size(0) + left_context_len
+ # length of positive side: x.size(0) + left_context_len
+ # length of negative side: x.size(0)
+ pos_emb = self.pe[
+ self.pe.size(0) // 2
+ - x_size_left
+ + 1 : self.pe.size(0) // 2 # noqa E203
+ + x.size(0),
+ :,
+ ]
+ pos_emb = pos_emb.unsqueeze(0)
+ return self.dropout(pos_emb)
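+ # Indexing sketch (illustrative numbers): for x with time = 4 and
+ # left_context_len = 0, self.pe covers relative offsets symmetrically around
+ # its center row (offset 0); the slice above picks rows
+ # pe[center - 3 : center + 4], i.e. the 2 * time - 1 = 7 offsets -3..3, and
+ # returns them with a leading batch dimension of 1.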
+
+
+class RelPositionMultiheadAttentionWeights(nn.Module):
+ r"""Module that computes multi-head attention weights with relative position encoding.
+ Various other modules consume the resulting attention weights: see, for example, the
+ SimpleAttention module which allows you to compute conventional attention.
+
+ This is quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
+ we have to write up the differences.
+
+
+ Args:
+ embed_dim: number of channels at the input to this module, e.g. 256
+ pos_dim: dimension of the positional encoding vectors, e.g. 128.
+ num_heads: number of heads to compute weights for, e.g. 8
+ query_head_dim: dimension of the query (and key), per head. e.g. 24.
+ pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
+ dropout: dropout probability for attn_output_weights. Default: 0.0.
+ pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
+ any given call to forward(), in training time.
+ lora_r: the bottleneck dimension of LoRA
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ dropout: float = 0.0,
+ pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
+ lora_r: int = 0,
+ lora_alpha: int = 4,
+ lora_dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.query_head_dim = query_head_dim
+ self.pos_head_dim = pos_head_dim
+ self.dropout = dropout
+ self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
+ self.name = None # will be overwritten in training code; for diagnostics.
+
+ key_head_dim = query_head_dim
+ in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
+
+ # the initial_scale is supposed to take over the "scaling" factor of
+ # head_dim ** -0.5 that has been used in previous forms of attention,
+ # dividing it between the query and key. Note: this module is intended
+ # to be used with the ScaledAdam optimizer; with most other optimizers,
+ # it would be necessary to apply the scaling factor in the forward function.
+ # self.in_proj = ScaledLinear(
+ # embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+ # )
+ self.in_proj = ScaledLinear_lora(
+ in_features=embed_dim,
+ out_features=in_proj_dim,
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ initial_scale=query_head_dim**-0.25,
+ bias=True,
+ )
+
+ self.whiten_keys = Whiten(
+ num_groups=num_heads,
+ whitening_limit=_whitening_schedule(3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.025,
+ )
+
+ # add a balancer for the keys that runs with very small probability, and
+ # tries to enforce that all dimensions have mean around zero. The
+ # weights produced by this module are invariant to adding a constant to
+ # the keys, so the derivative of the bias is mathematically zero; but
+ # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
+ # bias because the small numerical roundoff tends to have a non-random
+ # sign. This module is intended to prevent that. Use a very small
+ # probability; that should be sufficient to fix the problem.
+ self.balance_keys = Balancer(
+ key_head_dim * num_heads,
+ channel_dim=-1,
+ min_positive=0.4,
+ max_positive=0.6,
+ min_abs=0.0,
+ max_abs=100.0,
+ prob=0.025,
+ )
+
+ # linear transformation for positional encoding.
+ self.linear_pos = ScaledLinear(
+ pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
+ )
+
+ # the following are for diagnostics only, see --print-diagnostics option
+ self.copy_pos_query = Identity()
+ self.copy_query = Identity()
+
+ def forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
+ are True in this mask will be ignored as sources in the attention weighting.
+ attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
+ interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
+ saying which positions are allowed to attend to which other positions.
+ Returns:
+ a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
+ interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim
+
+ q = self.copy_query(q) # for diagnostics only, does nothing.
+ k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
+ p = self.copy_pos_query(p) # for diagnostics only, does nothing.
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ use_pos_scores = False
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ # We can't put random.random() in the same line
+ use_pos_scores = True
+ elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+ use_pos_scores = True
+
+ if use_pos_scores:
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+ # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+
+ # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+ # [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+ # the following .as_strided() expression converts the last axis of pos_scores from relative
+ # to absolute position. I don't know whether I might have got the time-offsets backwards or
+ # not, but let this code define which way round it is supposed to be.
+ if torch.jit.is_tracing():
+ (num_heads, batch_size, time1, n) = pos_scores.shape
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+ cols = torch.arange(seq_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+ pos_scores = pos_scores.reshape(-1, n)
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
+ else:
+ pos_scores = pos_scores.as_strided(
+ (num_heads, batch_size, seq_len, seq_len),
+ (
+ pos_scores.stride(0),
+ pos_scores.stride(1),
+ pos_scores.stride(2) - pos_scores.stride(3),
+ pos_scores.stride(3),
+ ),
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
+ )
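+ # Illustrative reading of this relative->absolute step (not executed): for
+ # seq_len = 3, pos_scores[..., t1, :] holds scores for offsets [-2, -1, 0, 1, 2];
+ # the strided view picks, for target frame t1 and source frame t2, the entry
+ # at relative offset (t2 - t1), so row t1 = 0 reads offsets [0, 1, 2] and row
+ # t1 = 2 reads offsets [-2, -1, 0].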
+
+ attn_scores = attn_scores + pos_scores
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif self.training and random.random() < 0.1:
+ # This is a harder way of limiting the attention scores to not be
+ # too large. It incurs a penalty if any of them has an absolute
+ # value greater than 50.0. this should be outside the normal range
+ # of the attention scores. We use this mechanism instead of, say,
+ # something added to the loss function involving the entropy,
+ # because once the entropy gets very small gradients through the
+ # softmax can become very small, and we'd get zero derivatives. The
+ # choices of 1.0e-04 as the scale on the penalty makes this
+ # mechanism vulnerable to the absolute scale of the loss function,
+ # but we view this as a failsafe to avoid "implausible" parameter
+ # values rather than a regularization method that should be active
+ # under normal circumstances.
+ attn_scores = penalize_abs_values_gt(
+ attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
+ )
+
+ assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ if attn_mask is not None:
+ assert attn_mask.dtype == torch.bool
+ # use -1000 to avoid nan's where attn_mask and key_padding_mask make
+ # all scores zero. It's important that this be large enough that exp(-1000)
+ # is exactly zero, for reasons related to const_attention_rate, it
+ # compares the final weights with zero.
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (
+ batch_size,
+ seq_len,
+ ), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+ # We use our own version of softmax, defined in scaling.py, which should
+ # save a little of the memory used in backprop: if we are in
+ # automatic mixed precision mode (amp / autocast), it only stores the
+ # half-precision output for backprop purposes.
+ attn_weights = softmax(attn_scores, dim=-1)
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif random.random() < 0.001 and not self.training:
+ self._print_attn_entropy(attn_weights)
+
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+
+ return attn_weights
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ cached_key: Tensor,
+ left_context_len: int,
+ key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
+ cached_key: cached attention key tensor of left context,
+ of shape (left_context_len, batch_size, key_dim)
+ left_context_len: number of left context frames.
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
+ are True in this mask will be ignored as sources in the attention weighting.
+
+ Returns:
+ - attention weights, of shape (num_heads, batch_size, seq_len, seq_len2),
+ interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ - updated cached attention key tensor of left context.
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim
+
+ # Pad cached left contexts
+ assert cached_key.shape[0] == left_context_len, (
+ cached_key.shape[0],
+ left_context_len,
+ )
+ k = torch.cat([cached_key, k], dim=0)
+ # Update cached left contexts
+ cached_key = k[-left_context_len:, ...]
+
+ # The length of key
+ k_len = k.shape[0]
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(k_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1 + left_context_len
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+ # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+
+ # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+ # [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+
+ if torch.jit.is_tracing():
+ (num_heads, batch_size, time1, n) = pos_scores.shape
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+ cols = torch.arange(k_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+ pos_scores = pos_scores.reshape(-1, n)
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
+ # the following .as_strided() expression converts the last axis of pos_scores from relative
+ # to absolute position. I don't know whether I might have got the time-offsets backwards or
+ # not, but let this code define which way round it is supposed to be.
+ else:
+ pos_scores = pos_scores.as_strided(
+ (num_heads, batch_size, seq_len, k_len),
+ (
+ pos_scores.stride(0),
+ pos_scores.stride(1),
+ pos_scores.stride(2) - pos_scores.stride(3),
+ pos_scores.stride(3),
+ ),
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
+ )
+
+ attn_scores = attn_scores + pos_scores
+
+ assert attn_scores.shape == (
+ num_heads,
+ batch_size,
+ seq_len,
+ k_len,
+ ), attn_scores.shape
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+ attn_weights = attn_scores.softmax(dim=-1)
+
+ return attn_weights, cached_key
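+ # Streaming cache sketch (assumed sizes): if left_context_len = 64 and the new
+ # chunk has seq_len = 16, the key used for attention has k_len = 64 + 16 = 80
+ # frames, attn_weights has shape (num_heads, batch_size, 16, 80), and the last
+ # 64 key frames are kept as the updated cached_key for the next chunk.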
+
+ def _print_attn_entropy(self, attn_weights: Tensor):
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+
+ with torch.no_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ attn_weights = attn_weights.to(torch.float32)
+ attn_weights_entropy = (
+ -((attn_weights + 1.0e-20).log() * attn_weights)
+ .sum(dim=-1)
+ .mean(dim=(1, 2))
+ )
+ logging.info(
+ f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
+ )
+
+
+class SelfAttention(nn.Module):
+ """
+ The simplest possible attention module. This one works with already-computed attention
+ weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
+
+ Args:
+ embed_dim: the input and output embedding dimension
+ num_heads: the number of attention heads
+ value_head_dim: the value dimension per head
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ value_head_dim: int,
+ lora_r: int = 0,
+ lora_alpha: int = 4,
+ lora_dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.in_proj = ScaledLinear_lora(
+ in_features=embed_dim,
+ out_features=num_heads * value_head_dim,
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ bias=True,
+ )
+
+ self.out_proj = ScaledLinear(
+ num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ Returns:
+ a tensor with the same shape as x.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+ # v: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+ x = self.whiten(x)
+
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ cached_val: Tensor,
+ left_context_len: int,
+ ) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ cached_val: cached attention value tensor of left context,
+ of shape (left_context_len, batch_size, value_dim)
+ left_context_len: number of left context frames.
+
+ Returns:
+ - attention weighted output, a tensor with the same shape as x.
+ - updated cached attention value tensor of left context.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ seq_len2 = seq_len + left_context_len
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+
+ # Pad cached left contexts
+ assert cached_val.shape[0] == left_context_len, (
+ cached_val.shape[0],
+ left_context_len,
+ )
+ x = torch.cat([cached_val, x], dim=0)
+ # Update cached left contexts
+ cached_val = x[-left_context_len:, ...]
+
+ x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len2, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+ # v: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+
+ return x, cached_val
+
+
+class FeedforwardModule(nn.Module):
+ """Feedforward module in Zipformer2 model."""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ feedforward_dim: int,
+ dropout: FloatLike,
+ lora_r: int = 0,
+ lora_alpha: int = 4,
+ lora_dropout: float = 0.0,
+ ):
+ super(FeedforwardModule, self).__init__()
+ self.in_proj = ScaledLinear_lora(
+ in_features=embed_dim,
+ out_features=feedforward_dim,
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ bias=True,
+ )
+
+ self.hidden_balancer = Balancer(
+ feedforward_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=1.0,
+ min_abs=0.75,
+ max_abs=5.0,
+ )
+
+ # shared_dim=0 means we share the dropout mask along the time axis
+ self.out_proj = ActivationDropoutAndLinear_lora(
+ feedforward_dim,
+ embed_dim,
+ activation="SwooshL",
+ dropout_p=dropout,
+ dropout_shared_dim=0,
+ bias=True,
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ initial_scale=0.1,
+ )
+
+ self.out_whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(self, x: Tensor):
+ x = self.in_proj(x)
+ x = self.hidden_balancer(x)
+ # out_proj contains SwooshL activation, then dropout, then linear.
+ x = self.out_proj(x)
+ x = self.out_whiten(x)
+ return x
+
+
+class NonlinAttention(nn.Module):
+ """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
+ from the attention module) in place of actual convolution. We also took out the second nonlinearity, the
+ one after the attention mechanism.
+
+ Args:
+ channels (int): The number of channels of conv layers.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ hidden_channels: int,
+ ) -> None:
+ super().__init__()
+
+ self.hidden_channels = hidden_channels
+
+ self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
+
+ # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0,
+ # because we noticed that well-trained instances of this module have abs-value before the sigmoid
+ # starting from about 3, and poorly-trained instances of the module have smaller abs values
+ # before the sigmoid.
+ self.balancer = Balancer(
+ hidden_channels,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
+ max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
+ min_abs=0.5,
+ max_abs=5.0,
+ )
+ self.tanh = nn.Tanh()
+
+ self.identity1 = Identity() # for diagnostics.
+ self.identity2 = Identity() # for diagnostics.
+ self.identity3 = Identity() # for diagnostics.
+
+ self.out_proj = ScaledLinear(
+ hidden_channels, channels, bias=True, initial_scale=0.05
+ )
+
+ self.whiten1 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.whiten2 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """.
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+ Returns:
+ a Tensor with the same shape as x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
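+        # The projection splits into three equal parts: s becomes a tanh gate,
+        # x is mixed by the attention weights, and y multiplies the result at the end.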
+
+ # s will go through tanh.
+
+ s = self.balancer(s)
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = self.whiten1(x)
+ x = x * s
+ x = self.identity1(x) # diagnostics only, it's the identity.
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = torch.matmul(attn_weights, x)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ y = self.identity2(y)
+ x = x * y
+ x = self.identity3(x)
+
+ x = self.out_proj(x)
+ x = self.whiten2(x)
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ cached_x: Tensor,
+ left_context_len: int,
+ ) -> Tuple[Tensor, Tensor]:
+ """.
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+ cached_x: left context, a Tensor of shape
+ (num_heads, batch_size, left_context_len, head_dim)
+ left_context_len: number of left context frames.
+ Returns:
+ - a Tensor with the same shape as x
+ - updated left context with same shape as cached_x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
+
+ # s will go through tanh.
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = x * s
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (
+ num_heads,
+ batch_size,
+ seq_len,
+ left_context_len + seq_len,
+ )
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+
+ # Pad cached tensor
+ assert cached_x.shape[2] == left_context_len, (
+ cached_x.shape[2],
+ left_context_len,
+ )
+ x_pad = torch.cat([cached_x, x], dim=2)
+ # Update cached tensor
+ cached_x = x_pad[:, :, -left_context_len:, :]
+
+ x = torch.matmul(attn_weights, x_pad)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ x = x * y
+
+ x = self.out_proj(x)
+ return x, cached_x
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Zipformer2 model.
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
+
+ Args:
+ channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+ bias (bool): Whether to use bias in conv layers (default=True).
+
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ causal: bool,
+ ) -> None:
+ """Construct a ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ bottleneck_dim = channels
+ self.causal = causal
+
+ self.in_proj = nn.Linear(
+ channels,
+ 2 * bottleneck_dim,
+ )
+ # the gradients on in_proj are a little noisy, likely to do with the
+ # sigmoid in glu.
+
+ # after in_proj we put x through a gated linear unit (nn.functional.glu).
+ # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+ # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+ # between 50 and 100 for different channels. This will cause very peaky and
+ # sparse derivatives for the sigmoid gating function, which will tend to make
+ # the loss function not learn effectively. (for most layers the average absolute values
+ # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+ # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
+ # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.) The idea is that if we
+ # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+ # it will be in a better position to start learning something, i.e. to latch onto
+ # the correct range.
+ self.balancer1 = Balancer(
+ bottleneck_dim,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
+ max_positive=1.0,
+ min_abs=1.5,
+ max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
+ )
+
+ self.activation1 = Identity() # for diagnostics
+
+ self.sigmoid = nn.Sigmoid()
+
+ self.activation2 = Identity() # for diagnostics
+
+ assert kernel_size % 2 == 1
+
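+        # When causal=True (streaming), use a chunk-wise causal depthwise conv that sees only
+        # the current chunk plus cached left context; otherwise use a standard 'SAME'-padded
+        # depthwise conv.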
+ self.depthwise_conv = (
+ ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
+ if causal
+ else nn.Conv1d(
+ in_channels=bottleneck_dim,
+ out_channels=bottleneck_dim,
+ groups=bottleneck_dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ )
+ )
+
+ self.balancer2 = Balancer(
+ bottleneck_dim,
+ channel_dim=1,
+ min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
+ max_positive=1.0,
+ min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
+ max_abs=10.0,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.out_proj = ActivationDropoutAndLinear(
+ bottleneck_dim,
+ channels,
+ activation="SwooshR",
+ dropout_p=0.0,
+ initial_scale=0.05,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ chunk_size: int = -1,
+ ) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+ (batch, #time), contains True in masked positions.
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
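+        # GLU-style gating: split the projection in two; s is passed through a sigmoid and multiplies x.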
+ x, s = x.chunk(2, dim=2)
+ s = self.balancer1(s)
+ s = self.sigmoid(s)
+ x = self.activation1(x) # identity.
+ x = x * s
+ x = self.activation2(x) # identity
+
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
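+        # Zero out padded frames so that padding does not leak into neighboring
+        # frames through the depthwise convolution.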
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ if (
+ not torch.jit.is_scripting()
+ and not torch.jit.is_tracing()
+ and chunk_size >= 0
+ ):
+            # Exporting a model for simulated streaming decoding is not supported
+ assert (
+ self.causal
+ ), "Must initialize model with causal=True if you use chunk_size"
+ x = self.depthwise_conv(x, chunk_size=chunk_size)
+ else:
+ x = self.depthwise_conv(x)
+
+ x = self.balancer2(x)
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.whiten(x) # (time, batch, channels)
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ cache: Tensor,
+ src_key_padding_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Compute convolution module in streaming forward mode.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ cache: cached left context for depthwise_conv of shape
+ (#batch, channels, left_pad)
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+ (batch, #time), contains True in masked positions.
+
+ Returns:
+ - Output tensor (#time, batch, channels).
+ - Updated cache (#batch, channels, left_pad)
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
+ x, s = x.chunk(2, dim=2)
+ s = self.sigmoid(s)
+ x = x * s
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)
+
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x, cache
+
+
+class ScalarMultiply(nn.Module):
+ def __init__(self, scale: float):
+ super().__init__()
+ self.scale = scale
+
+ def forward(self, x):
+ return x * self.scale
+
+
+def _test_zipformer_main(causal: bool = False):
+ batch_size = 5
+ seq_len = 20
+ # Just make sure the forward pass runs.
+
+ c = Zipformer2(
+ encoder_dim=(64, 96),
+ encoder_unmasked_dim=(48, 64),
+ num_heads=(4, 4),
+ causal=causal,
+ chunk_size=(4,) if causal else (-1,),
+ left_context_frames=(64,),
+ )
+ batch_size = 5
+ seq_len = 20
+ # Just make sure the forward pass runs.
+ f = c(
+ torch.randn(seq_len, batch_size, 64),
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
+ )
+ f[0].sum().backward()
+ c.eval()
+ f = c(
+ torch.randn(seq_len, batch_size, 64),
+ torch.full((batch_size,), seq_len, dtype=torch.int64),
+ )
+ f # to remove flake8 warnings
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.INFO)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_zipformer_main(False)
+ _test_zipformer_main(True)
diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py
index dd8949523..c1785a328 100755
--- a/egs/librispeech/ASR/zipformer_mmi/train.py
+++ b/egs/librispeech/ASR/zipformer_mmi/train.py
@@ -79,6 +79,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon, UniqLexicon
from icefall.mmi import LFMMILoss
@@ -816,9 +817,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py
index 11a22eec1..b1b60077c 100644
--- a/egs/multi_zh-hans/ASR/whisper/train.py
+++ b/egs/multi_zh-hans/ASR/whisper/train.py
@@ -824,7 +824,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
+ 512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py
index c1bbd2ee8..447ca122f 100755
--- a/egs/multi_zh-hans/ASR/zipformer/train.py
+++ b/egs/multi_zh-hans/ASR/zipformer/train.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1020,9 +1021,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/multi_zh_en/ASR/prepare.sh b/egs/multi_zh_en/ASR/prepare.sh
index 9f2be5a5c..a1530be29 100755
--- a/egs/multi_zh_en/ASR/prepare.sh
+++ b/egs/multi_zh_en/ASR/prepare.sh
@@ -115,9 +115,6 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
cat ./data/lang_bpe_500/transcript_words.txt \
>> $lang_dir/text_words_segmentation
-
- cat ./data/lang_char/text \
- >> $lang_dir/text
fi
cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \
diff --git a/egs/multi_zh_en/ASR/zipformer/decode_stream.py b/egs/multi_zh_en/ASR/zipformer/decode_stream.py
new file mode 120000
index 000000000..b8d8ddfc4
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decode_stream.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py
deleted file mode 120000
index 13fd02a78..000000000
--- a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/streaming_decode.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py
new file mode 100755
index 000000000..7b9bd2d6c
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py
@@ -0,0 +1,869 @@
+#!/usr/bin/env python3
+# Copyright 2022-2024 Xiaomi Corporation (Authors: Wei Kang,
+# Fangjun Kuang,
+# Zengwei Yao,
+# Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./zipformer/streaming_decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --causal 1 \
+ --chunk-size 32 \
+ --left-context-frames 256 \
+ --exp-dir ./zipformer/exp \
+ --decoding-method greedy_search \
+ --num-decode-streams 2000
+"""
+
+import argparse
+import logging
+import math
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import numpy as np
+import sentencepiece as spm
+import torch
+from asr_datamodule import AsrDataModule
+from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from lhotse.cut import Cut
+from multi_dataset import MultiDataset
+from streaming_beam_search import (
+ fast_beam_search_one_best,
+ greedy_search,
+ modified_beam_search,
+)
+from torch import Tensor, nn
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Supported decoding methods are:
+ greedy_search
+ modified_beam_search
+ fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--num_active_paths",
+ type=int,
+ default=4,
+ help="""An interger indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=32,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--num-decode-streams",
+ type=int,
+ default=2000,
+ help="The number of streams that can be decoded parallel.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_init_states(
+ model: nn.Module,
+ batch_size: int = 1,
+ device: torch.device = torch.device("cpu"),
+) -> List[torch.Tensor]:
+ """
+ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+ is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+ states[-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+ """
+ states = model.encoder.get_init_states(batch_size, device)
+
+ embed_states = model.encoder_embed.get_init_states(batch_size, device)
+ states.append(embed_states)
+
+ processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+ states.append(processed_lens)
+
+ return states
+
+
+def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+ """Stack list of zipformer states that correspond to separate utterances
+    into a single zipformer state, so that it can be used as an input for
+ zipformer when those utterances are formed into a batch.
+
+ Args:
+ state_list:
+        Each element in state_list corresponds to the internal state
+ of the zipformer model for a single utterance. For element-n,
+ state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
+ state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
+ cached_val2, cached_conv1, cached_conv2).
+ state_list[n][-2] is the cached left padding for ConvNeXt module,
+ of shape (batch_size, num_channels, left_pad, num_freqs)
+ state_list[n][-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Note:
+ It is the inverse of :func:`unstack_states`.
+ """
+ batch_size = len(state_list)
+ assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
+ tot_num_layers = (len(state_list[0]) - 2) // 6
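+    # Each encoder layer contributes 6 cached tensors; the final two entries hold
+    # the encoder_embed cache and processed_lens.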
+
+ batch_states = []
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key = torch.cat(
+ [state_list[i][layer_offset] for i in range(batch_size)], dim=1
+ )
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn = torch.cat(
+ [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1 = torch.cat(
+ [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2 = torch.cat(
+ [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1 = torch.cat(
+ [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2 = torch.cat(
+ [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
+ )
+ batch_states += [
+ cached_key,
+ cached_nonlin_attn,
+ cached_val1,
+ cached_val2,
+ cached_conv1,
+ cached_conv2,
+ ]
+
+ cached_embed_left_pad = torch.cat(
+ [state_list[i][-2] for i in range(batch_size)], dim=0
+ )
+ batch_states.append(cached_embed_left_pad)
+
+ processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
+ batch_states.append(processed_lens)
+
+ return batch_states
+
+
+def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
+ """Unstack the zipformer state corresponding to a batch of utterances
+ into a list of states, where the i-th entry is the state from the i-th
+ utterance in the batch.
+
+ Note:
+ It is the inverse of :func:`stack_states`.
+
+ Args:
+ batch_states: A list of cached tensors of all encoder layers. For layer-i,
+        batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+        cached_conv1, cached_conv2).
+        batch_states[-2] is the cached left padding for ConvNeXt module,
+        of shape (batch_size, num_channels, left_pad, num_freqs)
+        batch_states[-1] is processed_lens of shape (batch,), which records the number
+ of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+ Returns:
+      state_list: A list of lists. Each element in state_list corresponds to the internal state
+ of the zipformer model for a single utterance.
+ """
+ assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
+ tot_num_layers = (len(batch_states) - 2) // 6
+
+ processed_lens = batch_states[-1]
+ batch_size = processed_lens.shape[0]
+
+ state_list = [[] for _ in range(batch_size)]
+
+ for layer in range(tot_num_layers):
+ layer_offset = layer * 6
+ # cached_key: (left_context_len, batch_size, key_dim)
+ cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
+ # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+ cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val1: (left_context_len, batch_size, value_dim)
+ cached_val1_list = batch_states[layer_offset + 2].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_val2: (left_context_len, batch_size, value_dim)
+ cached_val2_list = batch_states[layer_offset + 3].chunk(
+ chunks=batch_size, dim=1
+ )
+ # cached_conv1: (#batch, channels, left_pad)
+ cached_conv1_list = batch_states[layer_offset + 4].chunk(
+ chunks=batch_size, dim=0
+ )
+ # cached_conv2: (#batch, channels, left_pad)
+ cached_conv2_list = batch_states[layer_offset + 5].chunk(
+ chunks=batch_size, dim=0
+ )
+ for i in range(batch_size):
+ state_list[i] += [
+ cached_key_list[i],
+ cached_nonlin_attn_list[i],
+ cached_val1_list[i],
+ cached_val2_list[i],
+ cached_conv1_list[i],
+ cached_conv2_list[i],
+ ]
+
+ cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(cached_embed_left_pad_list[i])
+
+ processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
+ for i in range(batch_size):
+ state_list[i].append(processed_lens_list[i])
+
+ return state_list
+
+
+def streaming_forward(
+ features: Tensor,
+ feature_lens: Tensor,
+ model: nn.Module,
+ states: List[Tensor],
+ chunk_size: int,
+ left_context_len: int,
+) -> Tuple[Tensor, Tensor, List[Tensor]]:
+ """
+ Returns encoder outputs, output lengths, and updated states.
+ """
+ cached_embed_left_pad = states[-2]
+ (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
+ x=features,
+ x_lens=feature_lens,
+ cached_left_pad=cached_embed_left_pad,
+ )
+ assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+ src_key_padding_mask = make_pad_mask(x_lens)
+
+ # processed_mask is used to mask out initial states
+ processed_mask = torch.arange(left_context_len, device=x.device).expand(
+ x.size(0), left_context_len
+ )
+ processed_lens = states[-1] # (batch,)
+ # (batch, left_context_size)
+ processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
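+    # True marks left-context positions that have not been filled yet (still initial
+    # zero states); the flip aligns them with the oldest end of the cached left context.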
+ # Update processed lengths
+ new_processed_lens = processed_lens + x_lens
+
+ # (batch, left_context_size + chunk_size)
+ src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+ encoder_states = states[:-2]
+ (
+ encoder_out,
+ encoder_out_lens,
+ new_encoder_states,
+ ) = model.encoder.streaming_forward(
+ x=x,
+ x_lens=x_lens,
+ states=encoder_states,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
+
+ new_states = new_encoder_states + [
+ new_cached_embed_left_pad,
+ new_processed_lens,
+ ]
+ return encoder_out, encoder_out_lens, new_states
+
+
+def decode_one_chunk(
+ params: AttributeDict,
+ model: nn.Module,
+ decode_streams: List[DecodeStream],
+) -> List[int]:
+ """Decode one chunk frames of features for each decode_streams and
+ return the indexes of finished streams in a List.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ decode_streams:
+        A List of DecodeStream, each belonging to an utterance.
+ Returns:
+      Return a List containing the indexes of the finished DecodeStreams.
+ """
+ device = model.device
+ chunk_size = int(params.chunk_size)
+ left_context_len = int(params.left_context_frames)
+
+ features = []
+ feature_lens = []
+ states = []
+ processed_lens = [] # Used in fast-beam-search
+
+ for stream in decode_streams:
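+        # chunk_size counts encoder frames (about 50 Hz); encoder_embed subsamples by 2,
+        # so fetch 2 * chunk_size feature frames (about 100 Hz) per chunk.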
+ feat, feat_len = stream.get_feature_frames(chunk_size * 2)
+ features.append(feat)
+ feature_lens.append(feat_len)
+ states.append(stream.states)
+ processed_lens.append(stream.done_frames)
+
+ feature_lens = torch.tensor(feature_lens, device=device)
+ features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
+
+ # Make sure the length after encoder_embed is at least 1.
+    # The encoder_embed subsamples features to (T - 7) // 2 frames
+ # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+ tail_length = chunk_size * 2 + 7 + 2 * 3
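+    # e.g. with --chunk-size 32 this is 32 * 2 + 7 + 2 * 3 = 77 feature frames.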
+ if features.size(1) < tail_length:
+ pad_length = tail_length - features.size(1)
+ feature_lens += pad_length
+ features = torch.nn.functional.pad(
+ features,
+ (0, 0, 0, pad_length),
+ mode="constant",
+ value=LOG_EPS,
+ )
+
+ states = stack_states(states)
+
+ encoder_out, encoder_out_lens, new_states = streaming_forward(
+ features=features,
+ feature_lens=feature_lens,
+ model=model,
+ states=states,
+ chunk_size=chunk_size,
+ left_context_len=left_context_len,
+ )
+
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ if params.decoding_method == "greedy_search":
+ greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+ elif params.decoding_method == "fast_beam_search":
+ processed_lens = torch.tensor(processed_lens, device=device)
+ processed_lens = processed_lens + encoder_out_lens
+ fast_beam_search_one_best(
+ model=model,
+ encoder_out=encoder_out,
+ processed_lens=processed_lens,
+ streams=decode_streams,
+ beam=params.beam,
+ max_states=params.max_states,
+ max_contexts=params.max_contexts,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ modified_beam_search(
+ model=model,
+ streams=decode_streams,
+ encoder_out=encoder_out,
+ num_active_paths=params.num_active_paths,
+ )
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+
+ states = unstack_states(new_states)
+
+ finished_streams = []
+ for i in range(len(decode_streams)):
+ decode_streams[i].states = states[i]
+ decode_streams[i].done_frames += encoder_out_lens[i]
+ if decode_streams[i].done:
+ finished_streams.append(i)
+
+ return finished_streams
+
+
+def decode_dataset(
+ cuts: CutSet,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ cuts:
+        Lhotse CutSet containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding-method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+ Its value is a list of tuples. Each tuple contains two elements:
+ The first is the reference transcript, and the second is the
+ predicted result.
+ """
+ device = model.device
+
+ opts = FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+
+ log_interval = 100
+
+ decode_results = []
+ # Contain decode streams currently running.
+ decode_streams = []
+ for num, cut in enumerate(cuts):
+ # each utterance has a DecodeStream.
+ initial_states = get_init_states(model=model, batch_size=1, device=device)
+ decode_stream = DecodeStream(
+ params=params,
+ cut_id=cut.id,
+ initial_states=initial_states,
+ decoding_graph=decoding_graph,
+ device=device,
+ )
+
+ audio: np.ndarray = cut.load_audio()
+ # audio.shape: (1, num_samples)
+ assert len(audio.shape) == 2
+ assert audio.shape[0] == 1, "Should be single channel"
+ assert audio.dtype == np.float32, audio.dtype
+
+ # The trained model is using normalized samples
+ # - this is to avoid sending [-32k,+32k] signal in...
+ # - some lhotse AudioTransform classes can make the signal
+ # be out of range [-1, 1], hence the tolerance 10
+ assert (
+ np.abs(audio).max() <= 10
+ ), "Should be normalized to [-1, 1], 10 for tolerance..."
+
+ samples = torch.from_numpy(audio).squeeze(0)
+
+ fbank = Fbank(opts)
+ feature = fbank(samples.to(device))
+ decode_stream.set_features(feature, tail_pad_len=30)
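+        # tail_pad_len appends padding frames at the end of the utterance so that the
+        # final partial chunk can still be flushed through the encoder.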
+ decode_stream.ground_truth = cut.supervisions[0].text
+
+ decode_streams.append(decode_stream)
+
+ while len(decode_streams) >= params.num_decode_streams:
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ sp.decode(decode_streams[i].decoding_result()).split(),
+ )
+ )
+ del decode_streams[i]
+
+ if num % log_interval == 0:
+ logging.info(f"Cuts processed until now is {num}.")
+
+ # decode final chunks of last sequences
+ while len(decode_streams):
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ sp.decode(decode_streams[i].decoding_result()).split(),
+ )
+ )
+ del decode_streams[i]
+
+ if params.decoding_method == "greedy_search":
+ key = "greedy_search"
+ elif params.decoding_method == "fast_beam_search":
+ key = (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ )
+ elif params.decoding_method == "modified_beam_search":
+ key = f"num_active_paths_{params.num_active_paths}"
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+ return {key: decode_results}
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ assert params.causal, params.causal
+ assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+ assert (
+ "," not in params.left_context_frames
+ ), "left_context_frames should be one value in decoding."
+ params.suffix += f"-chunk-{params.chunk_size}"
+ params.suffix += f"-left-context-{params.left_context_frames}"
+
+ # for fast_beam_search
+ if params.decoding_method == "fast_beam_search":
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if start >= 0:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ decoding_graph = None
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ multi_dataset = MultiDataset(args)
+
+ def remove_short_utt(c: Cut):
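+        # Estimate the number of encoder output frames after subsampling and
+        # drop cuts that would produce no output.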
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ if T <= 0:
+ logging.warning(
+ f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}"
+ )
+ return T > 0
+
+ test_sets_cuts = multi_dataset.test_cuts()
+
+ test_sets = test_sets_cuts.keys()
+ test_cuts = [test_sets_cuts[k] for k in test_sets]
+ for test_set, test_cut in zip(test_sets, test_cuts):
+ logging.info(f"Decoding {test_set}")
+ test_cut = test_cut.filter(remove_short_utt)
+ results_dict = decode_dataset(
+ cuts=test_cut,
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py
index 310c8fe59..5dba584f7 100755
--- a/egs/multi_zh_en/ASR/zipformer/train.py
+++ b/egs/multi_zh_en/ASR/zipformer/train.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1042,9 +1043,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/spgispeech/ASR/zipformer/train.py b/egs/spgispeech/ASR/zipformer/train.py
index 1709a2845..ed66ca29b 100755
--- a/egs/spgispeech/ASR/zipformer/train.py
+++ b/egs/spgispeech/ASR/zipformer/train.py
@@ -89,6 +89,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -1020,9 +1021,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py
index aee3972cd..2108266ec 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py
@@ -78,6 +78,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -870,9 +871,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py
index 5ad01df27..14a44efb3 100755
--- a/egs/tedlium3/ASR/zipformer/train.py
+++ b/egs/tedlium3/ASR/zipformer/train.py
@@ -87,6 +87,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
@@ -985,9 +986,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/vctk/TTS/README.md b/egs/vctk/TTS/README.md
index c07516b77..c2703dbe2 100644
--- a/egs/vctk/TTS/README.md
+++ b/egs/vctk/TTS/README.md
@@ -10,7 +10,7 @@ The above information is from the [CSTR VCTK website](https://datashare.ed.ac.uk
This recipe provides a VITS model trained on the VCTK dataset.
-Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.
+Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-vctk-vits-2024-03-18), note that this model was pretrained on the Edinburgh DataShare VCTK dataset.
For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/vctk/vits.html).
@@ -21,7 +21,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
- --use-fp16 1 \
--exp-dir vits/exp \
--tokens data/tokens.txt
--max-duration 350
diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py
deleted file mode 100755
index c6636c3ad..000000000
--- a/egs/vctk/TTS/local/prepare_token_file.py
+++ /dev/null
@@ -1,104 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file reads the texts in given manifest and generates the file that maps tokens to IDs.
-"""
-
-import argparse
-import logging
-from pathlib import Path
-from typing import Dict
-
-from lhotse import load_manifest
-
-
-def get_args():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--manifest-file",
- type=Path,
- default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"),
- help="Path to the manifest file",
- )
-
- parser.add_argument(
- "--tokens",
- type=Path,
- default=Path("data/tokens.txt"),
- help="Path to the tokens",
- )
-
- return parser.parse_args()
-
-
-def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
- """Write a symbol to ID mapping to a file.
-
- Note:
- No need to implement `read_mapping` as it can be done
- through :func:`k2.SymbolTable.from_file`.
-
- Args:
- filename:
- Filename to save the mapping.
- sym2id:
- A dict mapping symbols to IDs.
- Returns:
- Return None.
- """
- with open(filename, "w", encoding="utf-8") as f:
- for sym, i in sym2id.items():
- f.write(f"{sym} {i}\n")
-
-
-def get_token2id(manifest_file: Path) -> Dict[str, int]:
- """Return a dict that maps token to IDs."""
- extra_tokens = [
- "", # 0 for blank
- "", # 1 for sos and eos symbols.
- "", # 2 for OOV
- ]
- all_tokens = set()
-
- cut_set = load_manifest(manifest_file)
-
- for cut in cut_set:
- # Each cut only contain one supervision
- assert len(cut.supervisions) == 1, len(cut.supervisions)
- for t in cut.tokens:
- all_tokens.add(t)
-
- all_tokens = extra_tokens + list(all_tokens)
-
- token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
- return token2id
-
-
-if __name__ == "__main__":
- formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
- logging.basicConfig(format=formatter, level=logging.INFO)
-
- args = get_args()
- manifest_file = Path(args.manifest_file)
- out_file = Path(args.tokens)
-
- token2id = get_token2id(manifest_file)
- write_mapping(out_file, token2id)
diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py
new file mode 120000
index 000000000..afc29a22b
--- /dev/null
+++ b/egs/vctk/TTS/local/prepare_token_file.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/local/prepare_token_file.py
\ No newline at end of file
diff --git a/egs/vctk/TTS/local/prepare_tokens_vctk.py b/egs/vctk/TTS/local/prepare_tokens_vctk.py
index 32e1c7dfa..0748eba5a 100755
--- a/egs/vctk/TTS/local/prepare_tokens_vctk.py
+++ b/egs/vctk/TTS/local/prepare_tokens_vctk.py
@@ -24,9 +24,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
import logging
from pathlib import Path
-import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
+from piper_phonemize import phonemize_espeak
from tqdm.auto import tqdm
@@ -37,17 +37,20 @@ def prepare_tokens_vctk():
partition = "all"
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
- g2p = g2p_en.G2p()
new_cuts = []
for cut in tqdm(cut_set):
# Each cut only contains one supervision
- assert len(cut.supervisions) == 1, len(cut.supervisions)
+ assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
text = cut.supervisions[0].text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
- cut.tokens = g2p(text)
+ tokens_list = phonemize_espeak(text, "en-us")
+ tokens = []
+ for t in tokens_list:
+ tokens.extend(t)
+ cut.tokens = tokens
new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts)
diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh
index 152c7b168..aab075312 100755
--- a/egs/vctk/TTS/prepare.sh
+++ b/egs/vctk/TTS/prepare.sh
@@ -78,6 +78,13 @@ fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for VCTK"
+ # We assume you have installed piper_phonemize and espnet_tts_frontend.
+ # If not, please install them with:
+ # - piper_phonemize:
+ # refer to https://github.com/rhasspy/piper-phonemize,
+ # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+ # - espnet_tts_frontend:
+ # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
./local/prepare_tokens_vctk.py
mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
@@ -111,14 +118,15 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
- # We assume you have installed g2p_en and espnet_tts_frontend.
+ # We assume you have installed piper_phonemize and espnet_tts_frontend.
# If not, please install them with:
- # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
- # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+ # - piper_phonemize:
+ # refer to https://github.com/rhasspy/piper-phonemize,
+ # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+ # - espnet_tts_frontend:
+ # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
- ./local/prepare_token_file.py \
- --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
- --tokens data/tokens.txt
+ ./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py
index 80d155626..d00450f08 100755
--- a/egs/vctk/TTS/vits/export-onnx.py
+++ b/egs/vctk/TTS/vits/export-onnx.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
-# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -97,7 +98,7 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
- meta.value = value
+ meta.value = str(value)
onnx.save(model, filename)
@@ -160,6 +161,7 @@ def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
+ n_speakers: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
@@ -212,10 +214,15 @@ def export_model_onnx(
)
meta_data = {
- "model_type": "VITS",
+ "model_type": "vits",
"version": "1",
"model_author": "k2-fsa",
- "comment": "VITS generator",
+ "comment": "icefall", # must be icefall for models from icefall
+ "language": "English",
+ "voice": "en-us", # Choose your language appropriately
+ "has_espeak": 1,
+ "n_speakers": n_speakers,
+ "sample_rate": 22050, # Must match the real sample rate
}
logging.info(f"meta_data: {meta_data}")
@@ -231,8 +238,7 @@ def main():
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
with open(args.speakers) as f:
@@ -265,6 +271,7 @@ def main():
model,
model_filename,
params.vocab_size,
+ params.num_spks,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py
index 06c25f02e..2e1abdefb 100755
--- a/egs/vctk/TTS/vits/infer.py
+++ b/egs/vctk/TTS/vits/infer.py
@@ -135,14 +135,16 @@ def infer_dataset(
batch_size = len(batch["tokens"])
tokens = batch["tokens"]
- tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = tokenizer.tokens_to_token_ids(
+ tokens, intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
- tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
.int()
@@ -214,8 +216,7 @@ def main():
device = torch.device("cuda", 0)
tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
# we need cut ids to display recognition results.
diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py
index d85c0a27b..ae6587338 100755
--- a/egs/vctk/TTS/vits/test_onnx.py
+++ b/egs/vctk/TTS/vits/test_onnx.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
-# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -122,7 +123,9 @@ def main():
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
- tokens = tokenizer.texts_to_token_ids([text])
+ tokens = tokenizer.texts_to_token_ids(
+ [text], intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
speaker = torch.tensor([1], dtype=torch.int64) # (1, )
diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py
index 56f167a17..55bd69327 100755
--- a/egs/vctk/TTS/vits/train.py
+++ b/egs/vctk/TTS/vits/train.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -342,14 +343,16 @@ def prepare_input(
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device)
)
- tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = tokenizer.tokens_to_token_ids(
+ tokens, intersperse_blank=True, add_sos=True, add_eos=True
+ )
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
- tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
@@ -812,8 +815,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
- params.blank_id = tokenizer.blank_id
- params.oov_id = tokenizer.oov_id
+ params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
vctk = VctkTtsDataModule(args)
diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py
index 52fc5179f..6c785d8c3 100644
--- a/egs/vctk/TTS/vits/tts_datamodule.py
+++ b/egs/vctk/TTS/vits/tts_datamodule.py
@@ -1,6 +1,7 @@
# Copyright 2021 Piotr Żelasko
-# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
-# Zengwei Yao)
+# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Zengwei Yao,
+# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py
index 4b7c1ca42..6ff500ab9 100644
--- a/egs/wenetspeech/ASR/whisper/train.py
+++ b/egs/wenetspeech/ASR/whisper/train.py
@@ -803,7 +803,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
- ) # allow 4 megabytes per sub-module
+ 512
+ )
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py
index b1557dedb..3d3762916 100755
--- a/egs/wenetspeech/ASR/zipformer/train.py
+++ b/egs/wenetspeech/ASR/zipformer/train.py
@@ -86,6 +86,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -985,9 +986,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py
index 76df7e8d5..3ad16fd11 100755
--- a/egs/wenetspeech/KWS/zipformer/finetune.py
+++ b/egs/wenetspeech/KWS/zipformer/finetune.py
@@ -111,6 +111,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -525,9 +526,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py
index 05acbd6a9..eddec7303 100755
--- a/egs/wenetspeech/KWS/zipformer/train.py
+++ b/egs/wenetspeech/KWS/zipformer/train.py
@@ -88,6 +88,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
@@ -1042,9 +1043,7 @@ def train_one_epoch(
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
@@ -1188,7 +1187,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
- 2**22
- ) # allow 4 megabytes per sub-module
+ 512
+ )
diagnostic = diagnostics.attach_diagnostics(model, opts)
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
index 8c53972fd..d24c27326 100755
--- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
@@ -81,6 +81,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
+from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -854,9 +855,7 @@ def train_one_epoch(
if cur_grad_scale < 0.01:
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
+ raise_grad_scale_is_too_small_error(cur_grad_scale)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
diff --git a/icefall/err.py b/icefall/err.py
new file mode 100644
index 000000000..065e2a53d
--- /dev/null
+++ b/icefall/err.py
@@ -0,0 +1,47 @@
+# Copyright 2024 Xiaomi Corp. (authors: Zengrui Jin,)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def raise_grad_scale_is_too_small_error(cur_grad_scale: float):
+ raise RuntimeError(
+ f"""
+ grad_scale is too small, exiting: {cur_grad_scale}
+
+ ========================= NOTE =========================
+ If you see this error, it means that the gradient scale is too small.
+
+ The default base_lr is 0.045 / 0.05 (depending on which recipe you are
+ using); this is an empirical value obtained mostly with 4 * 32GB V100
+ GPUs and a max_duration of approx. 1,000.
+ The proper value of base_lr may vary depending on the number of GPUs
+ and the value of max-duration you are using.
+
+ To fix this issue, you may need to adjust the value of base_lr accordingly.
+
+ We suggest decreasing the value of base_lr by 0.005 (e.g., from 0.045
+ to 0.04) and trying again. If the error persists, you may repeat the
+ process until base_lr reaches 0.02. (Note that this will lead to a certain
+ loss of performance, but it should work. You can compensate for this by
+ increasing num_epochs.)
+
+ If the error persists, you can seek help by raising an issue that includes
+ a detailed description of (a) your computational resources, (b) the base_lr,
+ (c) the max_duration you are using, and (d) the detailed configuration of
+ your model.
+
+ ========================================================
+ """
+ )
diff --git a/icefall/utils.py b/icefall/utils.py
index 31f9801d9..2cb2edf93 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -1081,9 +1081,11 @@ def write_surt_error_stats(
f"{cut_id}:\t"
+ " ".join(
(
- ref_word
- if ref_word == hyp_word
- else f"({ref_word}->{hyp_word})"
+ (
+ ref_word
+ if ref_word == hyp_word
+ else f"({ref_word}->{hyp_word})"
+ )
for ref_word, hyp_word in ali
)
),