diff --git a/.flake8 b/.flake8
index 410cb5482..cf276d0ba 100644
--- a/.flake8
+++ b/.flake8
@@ -15,7 +15,7 @@ per-file-ignores =
egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
egs/librispeech/ASR/zipformer/*.py: E501, E203
egs/librispeech/ASR/RESULTS.md: E999,
-
+ egs/ljspeech/TTS/vits/*.py: E501, E203
# invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605
diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-corpora-zipformer.sh
similarity index 65%
rename from .github/scripts/run-multi-zh_hans-zipformer.sh
rename to .github/scripts/run-multi-corpora-zipformer.sh
index dd32a94f8..90f859f43 100755
--- a/.github/scripts/run-multi-zh_hans-zipformer.sh
+++ b/.github/scripts/run-multi-corpora-zipformer.sh
@@ -51,6 +51,8 @@ for method in modified_beam_search fast_beam_search; do
$repo/test_wavs/DEV_T0000000002.wav
done
+rm -rf $repo
+
log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 ===="
repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/
@@ -92,4 +94,42 @@ for method in modified_beam_search fast_beam_search; do
$repo/test_wavs/DEV_T0000000000.wav \
$repo/test_wavs/DEV_T0000000001.wav \
$repo/test_wavs/DEV_T0000000002.wav
-done
\ No newline at end of file
+done
+
+rm -rf $repo
+
+cd ../../../egs/multi_zh_en/ASR
+log "==== Test icefall-asr-zipformer-multi-zh-en-2023-11-22 ===="
+repo_url=https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22/
+
+log "Downloading pre-trained model from $repo_url"
+git lfs install
+git clone $repo_url
+repo=$(basename $repo_url)
+
+log "Display test files"
+tree $repo/
+ls -lh $repo/test_wavs/*.wav
+
+./zipformer/pretrained.py \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \
+ --method greedy_search \
+$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \
+$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \
+$repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav
+
+for method in modified_beam_search fast_beam_search; do
+ log "$method"
+
+ ./zipformer/pretrained.py \
+ --method $method \
+ --beam-size 4 \
+ --checkpoint $repo/exp/pretrained.pt \
+ --bpe-model $repo/data/lang_bbpe_2000/bbpe.model \
+ $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_29.wav \
+ $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_55.wav \
+ $repo/test_wavs/_1634_210_2577_1_1525157964032_3712259_75.wav
+done
+
+rm -rf $repo
diff --git a/.github/workflows/run-multi-zh_hans-zipformer.yml b/.github/workflows/run-multi-corpora-zipformer.yml
similarity index 91%
rename from .github/workflows/run-multi-zh_hans-zipformer.yml
rename to .github/workflows/run-multi-corpora-zipformer.yml
index 72c0775a7..38f7eb908 100644
--- a/.github/workflows/run-multi-zh_hans-zipformer.yml
+++ b/.github/workflows/run-multi-corpora-zipformer.yml
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-name: run-multi-zh_hans-zipformer
+name: run-multi-corpora-zipformer
on:
push:
@@ -24,12 +24,12 @@ on:
types: [labeled]
concurrency:
- group: run_multi-zh_hans_zipformer-${{ github.ref }}
+ group: run_multi-corpora_zipformer-${{ github.ref }}
cancel-in-progress: true
jobs:
- run_multi-zh_hans_zipformer:
- if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer'
+ run_multi-corpora_zipformer:
+ if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'multi-zh_hans' || github.event.label.name == 'zipformer' || github.event.label.name == 'multi-corpora'
runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -81,4 +81,4 @@ jobs:
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
- .github/scripts/run-multi-zh_hans-zipformer.sh
+ .github/scripts/run-multi-corpora-zipformer.sh
diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst
new file mode 100644
index 000000000..aa891c072
--- /dev/null
+++ b/docs/source/recipes/TTS/index.rst
@@ -0,0 +1,7 @@
+TTS
+======
+
+.. toctree::
+ :maxdepth: 2
+
+ ljspeech/vits
diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst
new file mode 100644
index 000000000..385fd3c70
--- /dev/null
+++ b/docs/source/recipes/TTS/ljspeech/vits.rst
@@ -0,0 +1,113 @@
+VITS
+===============
+
+This tutorial shows you how to train a VITS model
+with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
+
+.. note::
+
+  The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/ljspeech/TTS
+ $ ./prepare.sh
+
+To run stage 1 to stage 5, use
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 1 --stop_stage 5
+
+
+Build Monotonic Alignment Search
+--------------------------------
+
+.. code-block:: bash
+
+ $ cd vits/monotonic_align
+ $ python setup.py build_ext --inplace
+ $ cd ../../
+
+
+Training
+--------
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ $ ./vits/train.py \
+ --world-size 4 \
+ --num-epochs 1000 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir vits/exp \
+    --tokens data/tokens.txt \
+ --max-duration 500
+
+.. note::
+
+ You can adjust the hyper-parameters to control the size of the VITS model and
+ the training configurations. For more details, please run ``./vits/train.py --help``.
+
+.. note::
+
+ The training can take a long time (usually a couple of days).
+
+Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``.
+
+
+Inference
+---------
+
+The inference part uses checkpoints saved by the training part, so you have to run the
+training part first. It will save the ground-truth and generated wavs to the directory
+``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0"
+ $ ./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+    --tokens data/tokens.txt \
+ --max-duration 500
+
+.. note::
+
+ For more details, please run ``./vits/infer.py --help``.
+
+
+Export models
+-------------
+
+Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
+``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+
+.. code-block:: bash
+
+ $ ./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+You can test the exported ONNX model with:
+
+.. code-block:: bash
+
+ $ ./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
+
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following link:
+
+ - ``_
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index 7265e1cf6..8df61f0d0 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -2,7 +2,7 @@ Recipes
=======
This page contains various recipes in ``icefall``.
-Currently, only speech recognition recipes are provided.
+Currently, we provide recipes for speech recognition, language modeling, and speech synthesis.
We may add recipes for other tasks as well in the future.
@@ -16,3 +16,4 @@ We may add recipes for other tasks as well in the future.
Non-streaming-ASR/index
Streaming-ASR/index
RNN-LM/index
+ TTS/index
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index d36dc5ed3..9f73a2073 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -261,10 +261,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
fi
if [ ! -f $lang_char_dir/HLG.fst ]; then
- lang_phone_dir=data/lang_phone
./local/prepare_lang_fst.py \
- --lang-dir $lang_phone_dir \
- --ngram-G ./data/lm/G_3_gram.fst.txt
+ --lang-dir $lang_char_dir \
+ --ngram-G ./data/lm/G_3_gram_char.fst.txt
fi
fi
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
index be58c4e43..696eea906 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/decode.py
@@ -641,7 +641,7 @@ def main():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
- context_graph.build(contexts)
+ context_graph.build([(c, 0.0) for c in contexts])
else:
context_graph = None
else:
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
similarity index 99%
rename from egs/aishell/ASR/pruned_transducer_stateless7/train2.py
rename to egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
index 057af297f..6027273b2 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/train2.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py
@@ -1234,6 +1234,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
index 2a9fc57d5..39d988cd0 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7/export-onnx.py
@@ -56,7 +56,7 @@ import torch.nn as nn
from decoder2 import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer
from icefall.checkpoint import (
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
index f5ae836fd..99110d6b6 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -686,7 +686,7 @@ def main():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
- context_graph.build(contexts)
+ context_graph.build([(c, 0.0) for c in contexts])
else:
context_graph = None
else:
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
similarity index 99%
rename from egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py
rename to egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 88eb34104..3c13c19c6 100755
--- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -1233,6 +1233,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md
index 991875aaa..6c20bab2c 100644
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/README.md
@@ -4,6 +4,6 @@ See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer
[./emformer.py](./emformer.py) and [./train.py](./train.py)
are basically the same as
-[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py).
-The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py)
+[./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py).
+The only purpose of [./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py)
is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).
diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
similarity index 99%
rename from egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
rename to egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index c09c9537c..61a3f27db 100755
--- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train2.py
+++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -1237,6 +1237,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
CommonVoiceAsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
similarity index 99%
rename from egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py
rename to egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index 4c866ddd8..acde72d80 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train2.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -1274,6 +1274,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
CSJAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
index ebdb596a5..b210430c6 100755
--- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
+++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
@@ -72,7 +72,7 @@ from pathlib import Path
import torch
from scaling_converter import convert_scaled_to_non_scaled
from tokenizer import Tokenizer
-from train2 import add_model_arguments, get_params, get_transducer_model
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/libriheavy/ASR/README.md b/egs/libriheavy/ASR/README.md
new file mode 100644
index 000000000..2498d017f
--- /dev/null
+++ b/egs/libriheavy/ASR/README.md
@@ -0,0 +1,6 @@
+# Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context
+
+Libriheavy is a labeled version of [Librilight](https://arxiv.org/pdf/1912.07875.pdf). Please refer to our repository [k2-fsa/libriheavy](https://github.com/k2-fsa/libriheavy) for more details. We also have a paper: *Libriheavy: a 50,000 hours ASR corpus with punctuation casing and context*, [Preprint available on arxiv](https://arxiv.org/abs/2309.08105).
+
+
+See [RESULTS](./RESULTS.md) for the results for icefall recipes.
diff --git a/egs/libriheavy/ASR/RESULTS.md b/egs/libriheavy/ASR/RESULTS.md
index 4fbedad98..513bbf72e 100644
--- a/egs/libriheavy/ASR/RESULTS.md
+++ b/egs/libriheavy/ASR/RESULTS.md
@@ -1,6 +1,116 @@
-## Results
+# Results
-### Zipformer PromptASR (zipformer + PromptASR + BERT text encoder)
+## zipformer (zipformer + pruned stateless transducer)
+
+See for more details.
+
+[zipformer](./zipformer)
+
+### Non-streaming
+
+#### Training on normalized text, i.e. Upper case without punctuation
+
+##### normal-scaled model, number of model parameters: 65805511, i.e., 65.81 M
+
+You can find a pretrained model, training logs at:
+
+
+Note: The repository above contains three models trained on different subsets of Libriheavy: exp (large set), exp_medium_subset (medium set),
+and exp_small_subset (small set).
+
+Results of models:
+
+| training set | decoding method | librispeech clean | librispeech other | libriheavy clean | libriheavy other | comment |
+|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------|
+| small | greedy search | 4.19 | 9.99 | 4.75 | 10.25 |--epoch 90 --avg 20 |
+| small | modified beam search| 4.05 | 9.89 | 4.68 | 10.01 |--epoch 90 --avg 20 |
+| medium | greedy search | 2.39 | 4.85 | 2.90 | 6.6 |--epoch 60 --avg 20 |
+| medium | modified beam search| 2.35 | 4.82 | 2.90 | 6.57 |--epoch 60 --avg 20 |
+| large | greedy search | 1.67 | 3.32 | 2.24 | 5.61 |--epoch 16 --avg 3 |
+| large | modified beam search| 1.62 | 3.36 | 2.20 | 5.57 |--epoch 16 --avg 3 |
+
+The training command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+# The values below are for the medium subset. Use --num-epochs 16 --lr-hours 20000 for large and --num-epochs 90 --lr-hours 5000 for small.
+python ./zipformer/train.py \
+  --world-size 4 \
+  --master-port 12365 \
+  --exp-dir zipformer/exp \
+  --num-epochs 60 \
+  --lr-hours 15000 \
+ --use-fp16 1 \
+ --start-epoch 1 \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --max-duration 1000 \
+ --subset medium
+```
+
+The decoding command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0"
+for m in greedy_search modified_beam_search; do
+ ./zipformer/decode.py \
+ --epoch 16 \
+ --avg 3 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000 \
+ --causal 0 \
+ --decoding-method $m
+done
+```
+
+#### Training on full formatted text, i.e. with casing and punctuation
+
+##### normal-scaled model, number of model parameters: 66074067 , i.e., 66M
+
+You can find a pretrained model, training logs at:
+
+
+Note: The repository above contains three models trained on different subsets of Libriheavy: exp (large set), exp_medium_subset (medium set),
+and exp_small_subset (small set).
+
+Results of models:
+
+| training set | decoding method | libriheavy clean (WER) | libriheavy other (WER) | libriheavy clean (CER) | libriheavy other (CER) | comment |
+|---------------|---------------------|-------------------|-------------------|------------------|------------------|--------------------|
+| small | modified beam search| 13.04 | 19.54 | 4.51 | 7.90 |--epoch 88 --avg 41 |
+| medium | modified beam search| 9.84 | 13.39 | 3.02 | 5.10 |--epoch 50 --avg 15 |
+| large | modified beam search| 7.76 | 11.32 | 2.41 | 4.22 |--epoch 16 --avg 2 |
+
+The training command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+# The values below are for the medium subset. Use --num-epochs 16 --lr-hours 20000 for large and --num-epochs 90 --lr-hours 10000 for small.
+python ./zipformer/train.py \
+  --world-size 4 \
+  --master-port 12365 \
+  --exp-dir zipformer/exp \
+  --num-epochs 60 \
+  --lr-hours 15000 \
+ --use-fp16 1 \
+ --train-with-punctuation 1 \
+ --start-epoch 1 \
+ --bpe-model data/lang_punc_bpe_756/bpe.model \
+ --max-duration 1000 \
+ --subset medium
+```
+
+The decoding command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0"
+for m in greedy_search modified_beam_search; do
+ ./zipformer/decode.py \
+ --epoch 16 \
+ --avg 3 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000 \
+ --causal 0 \
+ --decoding-method $m
+done
+```
+
+## Zipformer PromptASR (zipformer + PromptASR + BERT text encoder)
#### [zipformer_prompt_asr](./zipformer_prompt_asr)
diff --git a/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py
new file mode 100755
index 000000000..010531db2
--- /dev/null
+++ b/egs/libriheavy/ASR/local/compute_fbank_libriheavy.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the Libriheavy dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ LilcomChunkyWriter,
+)
+
+from icefall.utils import get_executor, str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+    """Parse the command-line options of the Libriheavy fbank computation script."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--manifest-dir",
+        type=str,
+        help="""The source directory that contains raw manifests.
+        """,
+        default="data/manifests",
+    )
+
+    parser.add_argument(
+        "--fbank-dir",
+        type=str,
+        help="""Fbank output dir
+        """,
+        default="data/fbank",
+    )
+
+    parser.add_argument(
+        "--subset",
+        type=str,
+        help="""Name of the subset to compute fbank for, e.g. small, medium or large.""",
+    )
+
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=20,
+        help="Number of dataloading workers used for reading the audio.",
+    )
+
+    parser.add_argument(
+        "--batch-duration",
+        type=float,
+        default=600.0,
+        help="The maximum number of audio seconds in a batch. "
+        "Determines batch size dynamically.",
+    )
+
+    parser.add_argument(
+        "--perturb-speed",
+        type=str2bool,
+        default=False,
+        help="Whether to use speed perturbation.",
+    )
+
+    parser.add_argument(
+        "--use-splits",
+        type=str2bool,
+        default=False,
+        help="Whether to compute fbank on splits.",
+    )
+
+    parser.add_argument(
+        "--num-splits",
+        type=int,
+        help="""The number of splits of the medium and large subset.
+        Only needed when --use-splits is true.""",
+    )
+
+    parser.add_argument(
+        "--start",
+        type=int,
+        default=0,
+        help="""Process pieces starting from this number (inclusive).
+        Only needed when --use-splits is true.""",
+    )
+
+    parser.add_argument(
+        "--stop",
+        type=int,
+        default=-1,
+        help="""Process pieces up to this number (exclusive). -1 means all pieces.
+        Only needed when --use-splits is true.""",
+    )
+
+    return parser.parse_args()
+
+
+def compute_fbank_libriheavy(args):
+    """Compute 80-bin fbank features for one whole subset and save the cuts."""
+    src_dir = Path(args.manifest_dir)
+    output_dir = Path(args.fbank_dir)
+    # Same as the split variant: make sure the output directory exists.
+    output_dir.mkdir(parents=True, exist_ok=True)
+    # os.cpu_count() may return None; fall back to a single job in that case.
+    num_jobs = min(15, os.cpu_count() or 1)
+    num_mel_bins = 80
+    subset = args.subset
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        output_cuts_path = output_dir / f"libriheavy_cuts_{subset}.jsonl.gz"
+        if output_cuts_path.exists():
+            logging.info(f"{output_cuts_path} exists - skipping")
+            return
+
+        input_cuts_path = src_dir / f"libriheavy_cuts_{subset}.jsonl.gz"
+        assert input_cuts_path.exists(), f"{input_cuts_path} does not exist!"
+        logging.info(f"Loading {input_cuts_path}")
+        cut_set = CutSet.from_file(input_cuts_path)
+        logging.info("Computing features")
+        cut_set = cut_set.compute_and_store_features(
+            extractor=extractor,
+            storage_path=f"{output_dir}/libriheavy_feats_{subset}",
+            # when an executor is specified, make more partitions
+            num_jobs=num_jobs if ex is None else 80,
+            executor=ex,
+            storage_type=LilcomChunkyWriter,
+        )
+        logging.info(f"Saving to {output_cuts_path}")
+        cut_set.to_file(output_cuts_path)
+
+
+def compute_fbank_libriheavy_splits(args):
+    """Compute kaldifeat fbank for the split manifests numbered [start, stop)."""
+    num_splits = args.num_splits
+    subset = args.subset
+    src_dir = f"{args.manifest_dir}/libriheavy_{subset}_split"
+    src_dir = Path(src_dir)
+    output_dir = f"{args.fbank_dir}/libriheavy_{subset}_split"
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    start = args.start
+    stop = args.stop
+    # A stop smaller than start (e.g. the default -1) means "up to the last split".
+    if stop < start:
+        stop = num_splits
+    stop = min(stop, num_splits)
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
+    logging.info(f"device: {device}")
+
+    num_digits = 8  # num_digits is fixed by lhotse split-lazy
+    for i in range(start, stop):
+        idx = f"{i + 1}".zfill(num_digits)
+        logging.info(f"Processing {idx}/{num_splits}")
+
+        cuts_path = output_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz"
+        if cuts_path.is_file():
+            logging.info(f"{cuts_path} exists - skipping")
+            continue
+
+        raw_cuts_path = src_dir / f"libriheavy_cuts_{subset}.{idx}.jsonl.gz"
+        if not raw_cuts_path.is_file():
+            logging.info(f"{raw_cuts_path} does not exist - skipping it")
+            continue
+
+        logging.info(f"Loading {raw_cuts_path}")
+        cut_set = CutSet.from_file(raw_cuts_path)
+        logging.info("Computing features")
+        # An .lca without its cuts file is presumably left by an unfinished run.
+        if (output_dir / f"libriheavy_feats_{subset}_{idx}.lca").exists():
+            logging.info(f"Removing {output_dir}/libriheavy_feats_{subset}_{idx}.lca")
+            os.remove(output_dir / f"libriheavy_feats_{subset}_{idx}.lca")
+
+        cut_set = cut_set.compute_and_store_features_batch(
+            extractor=extractor,
+            storage_path=f"{output_dir}/libriheavy_feats_{subset}_{idx}",
+            num_workers=args.num_workers,
+            batch_duration=args.batch_duration,
+            overwrite=True,
+        )
+        logging.info("About to split cuts into smaller chunks.")
+        cut_set = cut_set.trim_to_supervisions(
+            keep_overlapping=False, min_duration=None
+        )
+
+        logging.info(f"Saving to {cuts_path}")
+        cut_set.to_file(cuts_path)
+        logging.info(f"Saved to {cuts_path}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
+    logging.info(vars(args))
+    # --use-splits selects the per-split code path, which needs --num-splits.
+    if args.use_splits:
+        assert args.num_splits is not None, "Please provide num_splits"
+        compute_fbank_libriheavy_splits(args)
+    else:
+        compute_fbank_libriheavy(args)
diff --git a/egs/libriheavy/ASR/local/compute_fbank_musan.py b/egs/libriheavy/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/libriheavy/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/local/norm_text.py b/egs/libriheavy/ASR/local/norm_text.py
new file mode 100755
index 000000000..c2fc0d92d
--- /dev/null
+++ b/egs/libriheavy/ASR/local/norm_text.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import codecs
+import sys
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="""Path to the input text.
+ """,
+ )
+ return parser.parse_args()
+
+
+def remove_punc_to_upper(text: str) -> str:
+    """Upper-case text, keeping only ASCII letters, digits and apostrophes."""
+    keep = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+    text = text.replace("‘", "'").replace("’", "'")
+    chars = (c.upper() if c in keep else " " for c in text)
+    # split() + join squeezes whitespace runs and trims both ends.
+    return " ".join("".join(chars).split())
+
+
+def main():
+    """Normalize each line of --text (or of stdin) and print it to stdout."""
+    args = get_args()
+    if args.text:
+        f = codecs.open(args.text, encoding="utf-8")
+    else:
+        f = codecs.getreader("utf-8")(sys.stdin.buffer)
+    sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
+    # The reader is a context manager, so the input is closed even on error.
+    with f:
+        for line in f:
+            print(remove_punc_to_upper(line))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py
new file mode 100755
index 000000000..42f392cae
--- /dev/null
+++ b/egs/libriheavy/ASR/local/prepare_manifest.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gzip
+import json
+import sys
+from pathlib import Path
+
+
+def simple_cleanup(text: str) -> str:
+ table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
+ text = text.translate(table)
+ return text.strip()
+
+
+# Copy INPUT (gzipped JSON lines) to OUTPUT_DIR, setting "text" and dropping "custom".
+def main():
+    assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
+    fname = Path(sys.argv[1]).name
+    oname = Path(sys.argv[2]) / fname  # same file name, inside OUTPUT_DIR
+    with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
+        for line in fin:  # each line is parsed as one JSON object
+            cut = json.loads(line)
+            cut["supervisions"][0]["text"] = simple_cleanup(  # first supervision only
+                cut["supervisions"][0]["custom"]["texts"][0]
+            )
+            del cut["supervisions"][0]["custom"]
+            del cut["custom"]
+            fout.write((json.dumps(cut) + "\n").encode())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libriheavy/ASR/local/train_bpe_model.py b/egs/libriheavy/ASR/local/train_bpe_model.py
new file mode 100755
index 000000000..19caf43ab
--- /dev/null
+++ b/egs/libriheavy/ASR/local/train_bpe_model.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# You can install sentencepiece via:
+#
+# pip install sentencepiece
+#
+# Due to an issue reported in
+# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
+#
+# Please install a version >=0.1.96
+
+import argparse
+import shutil
+from pathlib import Path
+
+import sentencepiece as spm
+
+
+def get_args():
+    """Parse the command-line options of the BPE model training script."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        The generated bpe.model is saved to this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--byte-fallback",
+        action="store_true",
+        help="""Whether to enable byte_fallback when training bpe.""",
+    )
+
+    parser.add_argument(
+        "--character-coverage",
+        type=float,
+        default=1.0,
+        help="Character coverage in vocabulary.",
+    )
+
+    parser.add_argument(
+        "--transcript",
+        type=str,
+        help="Training transcript.",
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
+    )
+    return parser.parse_args()
+
+
+def main():
+    """Train a SentencePiece unigram model and copy it to LANG_DIR/bpe.model."""
+    args = get_args()
+    vocab_size = args.vocab_size
+    lang_dir = Path(args.lang_dir)
+    model_type = "unigram"
+
+    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+    train_text = args.transcript
+    input_sentence_size = 100000000
+
+    user_defined_symbols = ["<blk>", "<sos/eos>"]
+    unk_id = len(user_defined_symbols)
+    # Note: unk_id is fixed to 2.
+    # If you change it, you should also change other
+    # places that are using it.
+
+    model_file = Path(model_prefix + ".model")
+    if not model_file.is_file():
+        spm.SentencePieceTrainer.train(
+            input=train_text,
+            vocab_size=vocab_size,
+            model_type=model_type,
+            model_prefix=model_prefix,
+            input_sentence_size=input_sentence_size,
+            character_coverage=args.character_coverage,
+            user_defined_symbols=user_defined_symbols,
+            byte_fallback=args.byte_fallback,
+            unk_id=unk_id,
+            bos_id=-1,
+            eos_id=-1,
+        )
+    else:
+        print(f"{model_file} exists - skipping")
+        return
+
+    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh
new file mode 100755
index 000000000..af7e3c5b0
--- /dev/null
+++ b/egs/libriheavy/ASR/prepare.sh
@@ -0,0 +1,314 @@
+#!/usr/bin/env bash
+# Data preparation recipe for Libriheavy ASR; it runs the numbered stages (-1 to 11) selected by stage/stop_stage.
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail  # -e: exit on error; -u: unset variables are errors; -o pipefail: a pipeline fails if any command in it fails
+
+nj=15  # passed as --num-workers to ./local/compute_fbank_libriheavy.py
+stage=-1  # first stage to run
+stop_stage=100  # last stage to run
+export CUDA_VISIBLE_DEVICES=""  # NOTE(review): presumably hides all GPUs so the preparation runs on CPU - confirm
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/librilight
+# You can find small, medium, large, etc. inside it.
+#
+# - $dl_dir/libriheavy
+# You can find libriheavy_cuts_small.jsonl.gz, libriheavy_cuts_medium.jsonl.gz, etc. inside it.
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1  # NOTE(review): presumably lets command-line options override the variables set above - verify against shared/parse_options.sh
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+ # 5000
+ # 2000
+ # 1000
+ 500
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+fbank_dir=data/fbank  # fbank features and their .done markers are written here
+manifests_dir=data/manifests  # manifests, split directories and texts_* files are written here
+
+log() {
+ # This function is from espnet. It prints a timestamped message tagged with the caller's file name, line number and function name.
+ local fname=${BASH_SOURCE[1]##*/}  # file name of the caller with the directory part stripped
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"  # record the download directory in effect for this run
+
+# Stage -1: fetch and unpack the Libri-Light audio subsets into
+# $dl_dir/librilight. A subset whose directory already exists is skipped.
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "Stage -1: Download audio data."
+  # If you have pre-downloaded it to /path/to/librilight,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/librilight $dl_dir/librilight
+  #
+  # Paths are quoted so that a dl_dir containing spaces still works.
+  mkdir -p "$dl_dir/librilight"
+  for subset in small medium large; do
+    if [ ! -d "$dl_dir/librilight/${subset}" ]; then
+      # Only announce a download when one actually happens (as stage 0 does).
+      log "Downloading ${subset} subset."
+      # -c resumes a partially downloaded archive.
+      wget -P "$dl_dir/librilight" -c "https://dl.fbaipublicfiles.com/librilight/data/${subset}.tar"
+      tar xf "$dl_dir/librilight/${subset}.tar" -C "$dl_dir/librilight"
+    else
+      log "Skipping download, ${subset} subset exists."
+    fi
+  done
+fi
+
+# Stage 0: fetch the Libriheavy cut manifests from huggingface and, when it
+# is missing, the musan corpus.
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download manifests from huggingface."
+
+  # If you have pre-downloaded it to /path/to/libriheavy,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/libriheavy $dl_dir/libriheavy
+  #
+  mkdir -p $dl_dir/libriheavy
+  for subset in small medium large dev test_clean test_other; do
+    cuts_name=libriheavy_cuts_${subset}.jsonl.gz
+    # Guard clause: nothing to do when the manifest is already present.
+    if [ -e $dl_dir/libriheavy/$cuts_name ]; then
+      log "Skipping download, ${subset} subset exists."
+      continue
+    fi
+    log "Downloading ${subset} subset."
+    wget -P $dl_dir/libriheavy -c https://huggingface.co/datasets/pkufool/libriheavy/resolve/main/$cuts_name
+  done
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  [ -d $dl_dir/musan ] || lhotse download musan $dl_dir
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Download manifests from modelscope"
+ mkdir -p $dl_dir/libriheavy
+ if [ ! -e $dl_dir/libriheavy/libriheavy_cuts_small.jsonl.gz ]; then
+ cd $dl_dir/libriheavy
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://www.modelscope.cn/datasets/pkufool/Libriheavy.git
+ cd Libriheavy
+ git lfs pull --exclude "raw/*"
+ mv *.jsonl.gz ../
+ cd ..
+ rm -rf Libriheavy
+ cd ../../
+ fi
+fi
+
+# Stage 2: build the musan manifests once; a marker file records completion.
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+  mkdir -p $manifests_dir
+  musan_manifest_marker=$manifests_dir/.musan.done
+  if [ ! -e $musan_manifest_marker ]; then
+    lhotse prepare musan $dl_dir/musan $manifests_dir
+    touch $musan_manifest_marker
+  fi
+fi
+
+# Stage 3: run ./local/prepare_manifest.py on every downloaded cut manifest
+# whose counterpart does not yet exist in $manifests_dir.
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Prepare Libriheavy manifests"
+  mkdir -p $manifests_dir
+  for subset in small medium large dev test_clean test_other; do
+    # Already prepared: move on to the next subset.
+    if [ -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
+      continue
+    fi
+    log "Prepare manifest for subset : ${subset}"
+    ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir
+  done
+fi
+
+# Stage 4: compute musan fbank features once; a marker file records completion.
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  mkdir -p $fbank_dir
+  musan_fbank_marker=$fbank_dir/.musan.done
+  if [ ! -e $musan_fbank_marker ]; then
+    ./local/compute_fbank_musan.py
+    touch $musan_fbank_marker
+  fi
+fi
+
+# Stage 5: compute fbank features for the small training subset and the
+# dev/test subsets, one subset at a time.
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for small subset and validation subsets"
+  # Make sure the marker directory exists when this stage is run on its own.
+  mkdir -p $fbank_dir
+  for subset in test_clean test_other dev small; do
+    log "Computing $subset subset."
+    if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then
+      ./local/compute_fbank_libriheavy.py \
+        --manifest-dir ${manifests_dir} \
+        --subset ${subset} \
+        --fbank-dir $fbank_dir \
+        --num-workers $nj
+      # The guard above tests this marker but nothing in this script created
+      # it, so every rerun recomputed the features.
+      # NOTE(review): harmless if compute_fbank_libriheavy.py already writes
+      # the marker itself - confirm.
+      touch $fbank_dir/.libriheavy.${subset}.done
+    fi
+  done
+fi
+
+# Value handed to `lhotse split-lazy` as its last argument in stage 6.
+num_per_split=8000
+# Stage 6: split the medium and large manifests into pieces; a marker file
+# inside each split directory records completion.
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Split medium and large subsets."
+  for subset in medium large; do
+    log "Spliting subset : $subset"
+    split_dir=$manifests_dir/libriheavy_${subset}_split
+    mkdir -p $split_dir
+    # Already split: move on to the next subset.
+    if [ -e $split_dir/.split_completed ]; then
+      continue
+    fi
+    lhotse split-lazy $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz $split_dir $num_per_split
+    touch $split_dir/.split_completed
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Compute fbank for medium and large subsets"
+ mkdir -p $fbank_dir
+ chunk_size=20
+ for subset in medium large; do
+ if [ $subset == "large" ]; then
+ chunk_size=200
+ fi
+ num_splits=$(find $manifests_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz" | wc -l)
+ if [ ! -e $fbank_dir/.libriheavy.${subset}.done ]; then
+ for i in $(seq 0 1 6); do
+ start=$(( i * $chunk_size ))
+ end=$(( (i+1) * $chunk_size ))
+ ./local/compute_fbank_libriheavy.py \
+ --manifest-dir ${manifests_dir} \
+ --use-splits 1 \
+ --subset ${subset} \
+ --fbank-dir $fbank_dir \
+ --num-splits $num_splits \
+ --num-workers $nj \
+ --start $start \
+ --stop $end &
+ done
+ wait
+ touch $fbank_dir/.libriheavy.${subset}.done
+ fi
+ done
+fi
+
+# Stage 8: merge the per-split feature manifests of the medium and large
+# subsets into a single manifest per subset.
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Combine features for medium and large subsets."
+  for subset in medium large; do
+    log "Combining $subset subset."
+    combined=$fbank_dir/libriheavy_cuts_${subset}.jsonl.gz
+    if [ ! -f $combined ]; then
+      # $pieces is left unquoted on purpose: each path found must become a
+      # separate argument of `lhotse combine`.
+      pieces=$(find $fbank_dir/libriheavy_${subset}_split -name "libriheavy_cuts_${subset}.*.jsonl.gz")
+      lhotse combine $pieces $combined
+    fi
+  done
+fi
+
+# Stage 9: extract the supervision texts of the medium subset, pass them
+# through ./local/norm_text.py, and train one BPE model per entry in
+# vocab_sizes.
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Train BPE model for normalized text"
+
+  if [ ! -f data/texts ]; then
+    # jq prints each text as a JSON string; sed strips the first quote, all
+    # backslashes and the trailing quote.
+    gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
+      | jq '.supervisions[].text' \
+      | sed 's/"//;s/\\//g;s/"$//' \
+      | ./local/norm_text.py > data/texts
+  fi
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+
+    cp data/texts $lang_dir/text
+
+    # A model already exists for this vocabulary size: nothing to train.
+    if [ -f $lang_dir/bpe.model ]; then
+      continue
+    fi
+    ./local/train_bpe_model.py \
+      --lang-dir $lang_dir \
+      --vocab-size $vocab_size \
+      --transcript $lang_dir/text
+  done
+fi
+
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+ log "Stage 10: Train BPE model for unnormalized text"
+ if [ ! -f data/punc_texts ]; then
+ gunzip -c $manifests_dir/libriheavy_cuts_medium.jsonl.gz \
+ | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' > data/punc_texts
+ fi
+ for vocab_size in ${vocab_sizes[@]}; do
+ new_vacab_size = $(($vocab_size + 256))
+ lang_dir=data/lang_punc_bpe_${new_vocab_size}
+ mkdir -p $lang_dir
+
+ cp data/punc_texts $lang_dir/text
+
+ if [ ! -f $lang_dir/bpe.model ]; then
+ ./local/train_bpe_model.py \
+ --lang-dir $lang_dir \
+ --byte-fallback \
+ --vocab-size ${new_vocab_size} \
+ --byte-fallback \
+ --character-coverage 0.99 \
+ --transcript $lang_dir/text
+ fi
+ done
+fi
+
+# Stage 11: build the word table, a 3-gram ARPA LM and its text FST from the
+# normalized transcripts of the small, medium and large subsets.
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Prepare language model for normalized text"
+
+  for subset in small medium large; do
+    if [ ! -f $manifests_dir/texts_${subset} ]; then
+      gunzip -c $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz \
+        | jq '.supervisions[].text' | sed 's/"//;s/\\//g;s/"$//' \
+        | ./local/norm_text.py > $manifests_dir/texts_${subset}
+    fi
+  done
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/text ]; then
+    cat $manifests_dir/texts_small $manifests_dir/texts_medium $manifests_dir/texts_large > data/lm/text
+  fi
+
+  # words.txt layout: four reserved symbols (ids 0-3), then the sorted unique
+  # words from the text, then #0 and the sentence boundary symbols.
+  # NOTE(review): the angle-bracket symbol names arrived here as blanks; they
+  # are restored to the names used by the other icefall recipes - confirm.
+  (echo '<eps> 0'; echo '!SIL 1'; echo '<SPOKEN_NOISE> 2'; echo '<UNK> 3';) \
+    > data/lm/words.txt
+
+  # One word per line, de-duplicated, numbered from 4 (NR starts at 1).
+  cat data/lm/text | sed 's/ /\n/g' | sort -u | sed '/^$/d' \
+    | awk '{print $1" "NR+3}' >> data/lm/words.txt
+
+  num_lines=$(< data/lm/words.txt wc -l)
+  (echo "#0 $num_lines"; echo "<s> $(($num_lines + 1))"; echo "</s> $(($num_lines + 2))";) \
+    >> data/lm/words.txt
+
+  # Train LM on transcripts
+  if [ ! -f data/lm/3-gram.unpruned.arpa ]; then
+    python3 ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text data/lm/text \
+      -lm data/lm/3-gram.unpruned.arpa
+  fi
+
+  # We assume you have installed kaldilm; if not, please install
+  # it using: pip install kaldilm
+  # Fixed: the guard used to test G_3_gram_char.fst.txt, a file this stage
+  # never writes, so the FST was regenerated on every run.
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table=data/lm/words.txt \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      data/lm/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
+  fi
+fi
+
diff --git a/egs/libriheavy/ASR/zipformer/asr_datamodule.py b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
new file mode 100644
index 000000000..df761c1b8
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/asr_datamodule.py
@@ -0,0 +1,443 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ K2SpeechRecognitionDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class LibriHeavyAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--subset",
+ type=str,
+ default="S",
+ help="""The subset to be used. Should be S, M or L. Note: S subset
+ includes libriheavy_cuts_small.jsonl.gz, M subset includes
+ libriheavy_cuts_small.jsonl.gz and libriheavy_cuts_medium.jsonl.gz,
+ L subset includes libriheavy_cuts_small.jsonl.gz,
+ libriheavy_cuts_medium.jsonl.gz and libriheavy_cuts_large.jsonl.gz.
+ """,
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_small_cuts(self) -> CutSet:
+ logging.info("About to get small subset cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_medium_cuts(self) -> CutSet:
+ logging.info("About to get medium subset cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_large_cuts(self) -> CutSet:
+ logging.info("About to get large subset cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_cuts(self) -> CutSet:
+ logging.info("About to get dev cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_clean_cuts(self) -> CutSet:
+ logging.info("About to get the test-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_other_cuts(self) -> CutSet:
+ logging.info("About to get the test-other cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz"
+ )
diff --git a/egs/libriheavy/ASR/zipformer/beam_search.py b/egs/libriheavy/ASR/zipformer/beam_search.py
new file mode 120000
index 000000000..e24eca39f
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/decode.py b/egs/libriheavy/ASR/zipformer/decode.py
new file mode 100644
index 000000000..1928e2635
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/decode.py
@@ -0,0 +1,794 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+"""
+
+
+import argparse
+import logging
+import math
+import warnings
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriHeavyAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from text_normalization import remove_punc_to_upper
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ make_pad_mask,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)  # fill value used when decode_one_batch pads the features of causal models
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--train-with-punctuation",
+ type=str2bool,
+ default=False,
+ help="""Set to True, if the model was trained on texts with casing
+ and punctuation.""",
+ )
+
+ parser.add_argument(
+ "--post-normalization",
+ type=str2bool,
+ default=False,
+ help="""Upper case and remove all chars except ' and -
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
+ fast_beam_search_nbest_oracle.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ if params.causal:
+ # this seems to cause insertions at the end of the utterance if used with zipformer.
+ pad_len = 30
+ feature_lens += pad_len
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, pad_len),
+ value=LOG_EPS,
+ )
+
+ encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+ hyps = []
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest":
+ hyp_tokens = fast_beam_search_nbest(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "fast_beam_search_nbest_oracle":
+ hyp_tokens = fast_beam_search_nbest_oracle(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ num_paths=params.num_paths,
+ ref_texts=sp.encode(supervisions["text"]),
+ nbest_scale=params.nbest_scale,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp).split())
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif "fast_beam_search" in params.decoding_method:
+ key = f"beam_{params.beam}_"
+ key += f"max_contexts_{params.max_contexts}_"
+ key += f"max_states_{params.max_states}"
+ if "nbest" in params.decoding_method:
+ key += f"_num_paths_{params.num_paths}_"
+ key += f"nbest_scale_{params.nbest_scale}"
+
+ return {key: hyps}
+ else:
+ return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode every batch of a dataloader and collect the results per setting.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ decoding_graph:
+ The decoding graph; it is forwarded unchanged to
+ :func:`decode_one_batch`. Used only when --decoding_method is
+ fast_beam_search, fast_beam_search_nbest or
+ fast_beam_search_nbest_oracle.
+ Returns:
+ Return a dict whose keys are the setting names produced by
+ :func:`decode_one_batch` (e.g. "greedy_search" or "beam_size_4").
+ When both `params.post_normalization` and
+ `params.train_with_punctuation` are true, an additional key
+ "<name>_norm" holds the same utterances after
+ :func:`remove_punc_to_upper` was applied to reference and hypothesis.
+ Each value is a list of 3-tuples `(cut_id, ref_words, hyp_words)`,
+ where `ref_words` and `hyp_words` are lists of words.
+ """
+ num_cuts = 0
+
+ # Not every dataloader supports len(); fall back to "?" for the log line.
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ # Log less frequently for greedy search than for the other methods.
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ # All warnings raised while decoding are suppressed.
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ sp=sp,
+ decoding_graph=decoding_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ # Optionally score a second time with punctuation removed and text
+ # upper-cased on both sides; stored under "<name>_norm".
+ this_batch = []
+ if params.post_normalization and params.train_with_punctuation:
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = remove_punc_to_upper(ref_text).split()
+ hyp_words = remove_punc_to_upper(" ".join(hyp_words)).split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[f"{name}_norm"].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(
+ f"batch {batch_str}, cuts processed until now is {num_cuts}"
+ )
+ return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    """Write transcripts, error statistics and a WER summary for one test set.
+
+    For every setting in `results_dict` two files are written below
+    `params.res_dir`: `recogs-<test_set>-<key>-<suffix>.txt` with the sorted
+    transcripts and `errs-<test_set>-<key>-<suffix>.txt` with the detailed
+    error statistics. Afterwards a single `wer-summary-<test_set>-<suffix>.txt`
+    lists all settings ordered by WER, and the same table is logged.
+
+    Args:
+      params:
+        Must provide `res_dir` (a path supporting `/`) and `suffix`.
+      test_set_name:
+        Name of the test set; used in file names and log messages.
+      results_dict:
+        Maps a setting name to its list of `(cut_id, ref_words, hyp_words)`.
+    """
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    # The summary covers every setting, so its name must not depend on the
+    # loop variable `key`: that was undefined for an empty `results_dict` and
+    # otherwise just happened to be the last setting iterated.
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    # Only the first (lowest-WER) row is annotated as the best one.
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    """Decode LibriHeavy test-clean and test-other with a trained model.
+
+    Parses the command line, builds the result-file suffix from the decoding
+    options, loads (and optionally averages) checkpoints from `exp_dir`,
+    decodes both test sets and writes transcripts and WER statistics via
+    :func:`save_results`.
+    """
+    parser = get_parser()
+    LibriHeavyAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.causal:
+        assert (
+            "," not in params.chunk_size
+        ), "chunk_size should be one value in decoding."
+        assert (
+            "," not in params.left_context_frames
+        ), "left_context_frames should be one value in decoding."
+        params.suffix += f"-chunk-{params.chunk_size}"
+        params.suffix += f"-left-context-{params.left_context_frames}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            # One extra checkpoint is needed: the oldest one only marks the
+            # (excluded) start of the averaging range.
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    libriheavy = LibriHeavyAsrDataModule(args)
+
+    def normalize_text(c: Cut):
+        text = remove_punc_to_upper(c.supervisions[0].text)
+        c.supervisions[0].text = text
+        return c
+
+    test_clean_cuts = libriheavy.test_clean_cuts()
+    test_other_cuts = libriheavy.test_other_cuts()
+
+    # A model trained without punctuation is scored against references that
+    # went through the same normalization.
+    if not params.train_with_punctuation:
+        test_clean_cuts = test_clean_cuts.map(normalize_text)
+        test_other_cuts = test_other_cuts.map(normalize_text)
+
+    test_clean_dl = libriheavy.test_dataloaders(test_clean_cuts)
+    test_other_dl = libriheavy.test_dataloaders(test_other_cuts)
+
+    test_sets = ["test-clean", "test-other"]
+    test_dls = [test_clean_dl, test_other_dl]
+
+    # Use a distinct loop variable so the list being iterated is not rebound.
+    for test_set, test_dl in zip(test_sets, test_dls):
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libriheavy/ASR/zipformer/decoder.py b/egs/libriheavy/ASR/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/encoder_interface.py b/egs/libriheavy/ASR/zipformer/encoder_interface.py
new file mode 120000
index 000000000..c2eaca671
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/export-onnx.py b/egs/libriheavy/ASR/zipformer/export-onnx.py
new file mode 120000
index 000000000..70a15683c
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/export.py b/egs/libriheavy/ASR/zipformer/export.py
new file mode 120000
index 000000000..dfc1bec08
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/jit_pretrained.py b/egs/libriheavy/ASR/zipformer/jit_pretrained.py
new file mode 120000
index 000000000..25108391f
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/joiner.py b/egs/libriheavy/ASR/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/model.py b/egs/libriheavy/ASR/zipformer/model.py
new file mode 120000
index 000000000..cd7e07d72
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/onnx_decode.py b/egs/libriheavy/ASR/zipformer/onnx_decode.py
new file mode 120000
index 000000000..0573b88c5
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/onnx_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_decode.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/onnx_pretrained.py b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py
new file mode 120000
index 000000000..8f32f4ee7
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/onnx_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/optim.py b/egs/libriheavy/ASR/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/pretrained.py b/egs/libriheavy/ASR/zipformer/pretrained.py
new file mode 120000
index 000000000..0bd71dde4
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/pretrained.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/scaling.py b/egs/libriheavy/ASR/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/scaling_converter.py b/egs/libriheavy/ASR/zipformer/scaling_converter.py
new file mode 120000
index 000000000..b0ecee05e
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/subsampling.py b/egs/libriheavy/ASR/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/libriheavy/ASR/zipformer/text_normalization.py b/egs/libriheavy/ASR/zipformer/text_normalization.py
new file mode 100644
index 000000000..92590769c
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/text_normalization.py
@@ -0,0 +1,50 @@
+from num2words import num2words
+
+
+def remove_punc_to_upper(text: str) -> str:
+    """Upper-case `text` and drop everything except ASCII letters, digits and "'".
+
+    Curly single quotes are first mapped to a plain apostrophe. Every other
+    character outside the allowed set becomes a space, and runs of whitespace
+    are collapsed to single spaces with no leading or trailing space.
+    """
+    allowed = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+    unified = text.replace("‘", "'").replace("’", "'")
+    mapped = []
+    for ch in unified:
+        mapped.append(ch.upper() if ch in allowed else " ")
+    # split()/join collapses whitespace runs and trims both ends.
+    return " ".join("".join(mapped).split())
+
+
+def word_normalization(word: str) -> str:
+    """Normalize a single upper-cased word to its spoken form.
+
+    1. Expand a few abbreviations (MRS, MR, ST, ETC) to the full word.
+    2. Convert ordinal numbers such as 9TH or 21ST to English words.
+    3. Convert plain digit strings to English words; values strictly between
+       1500 and 2030 are read as years.
+
+    Hyphens produced by `num2words` are replaced with spaces. Any other word
+    is returned upper-cased.
+    """
+    # "ECT" is kept as an alias because earlier versions matched only that
+    # (misspelled) form of "ETC".
+    abbreviations = {
+        "MRS": "MISSUS",
+        "MR": "MISTER",
+        "ST": "SAINT",
+        "ETC": "ET CETERA",
+        "ECT": "ET CETERA",
+    }
+    if word in abbreviations:
+        return abbreviations[word]
+
+    # isdecimal() rather than isnumeric(): the latter also accepts characters
+    # such as "½" or "²", which int() cannot parse.
+    if word[-2:] in ("ST", "ND", "RD", "TH") and word[:-2].isdecimal():  # e.g 9TH, 6TH
+        word = num2words(word[:-2], to="ordinal")
+        word = word.replace("-", " ")
+
+    if word.isdecimal():
+        num = int(word)
+        if num > 1500 and num < 2030:
+            word = num2words(word, to="year")
+        else:
+            word = num2words(word)
+        word = word.replace("-", " ")
+    return word.upper()
+
+
+def text_normalization(text: str) -> str:
+    """Upper-case `text` and apply `word_normalization` to each whitespace-separated word."""
+    normalized = (word_normalization(token) for token in text.upper().split())
+    return " ".join(normalized)
+
+
+if __name__ == "__main__":
+ # Self-checks that run when this module is executed as a script: one for
+ # punctuation stripping, one for abbreviation and ordinal expansion.
+ assert remove_punc_to_upper("I like this 《book>") == "I LIKE THIS BOOK"
+ assert (
+ text_normalization("Hello Mrs st 21st world 3rd she 99th MR")
+ == "HELLO MISSUS SAINT TWENTY FIRST WORLD THIRD SHE NINETY NINTH MISTER"
+ )
diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py
new file mode 100644
index 000000000..c97da4a11
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/train.py
@@ -0,0 +1,1415 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey,
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --full-libri 1 \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --full-libri 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriHeavyAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from text_normalization import remove_punc_to_upper
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+    """Return the batch count rescaled to the reference batch duration.
+
+    This is the number of batches that would have been used so far had every
+    step covered `params.ref_duration` instead of
+    `params.max_duration * params.world_size`. The value is meant to be
+    passed to `set_batch_count()`.
+    """
+    duration_per_step = params.max_duration * params.world_size
+    return params.batch_idx_train * duration_per_step / params.ref_duration
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    """Propagate `batch_count` to every sub-module that has a `batch_count` attribute.
+
+    A DDP wrapper is unwrapped first. Sub-modules that have a `name`
+    attribute additionally get it set to their qualified name from
+    `named_modules()`.
+    """
+    root = model.module if isinstance(model, DDP) else model
+    for module_name, submodule in root.named_modules():
+        if hasattr(submodule, "batch_count"):
+            submodule.batch_count = batch_count
+        if hasattr(submodule, "name"):
+            submodule.name = module_name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    """Register the model-architecture options on `parser`.
+
+    Per-stack encoder options are taken as comma-separated strings and are
+    parsed later; scalar options use their native type.
+    """
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,2,3,4,3,2",
+        help="Number of zipformer encoder layers per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--downsampling-factor",
+        type=str,
+        default="1,2,4,8,4,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dim",
+        type=str,
+        default="512,768,1024,1536,1024,768",
+        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--num-heads",
+        type=str,
+        default="4,4,4,8,4,4",
+        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=str,
+        default="192,256,384,512,384,256",
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--query-head-dim",
+        type=str,
+        default="32",
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--value-head-dim",
+        type=str,
+        default="12",
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--pos-head-dim",
+        type=str,
+        default="4",
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    # The default is an int to match `type=int`; the parsed value is the same
+    # as with the former string default "48".
+    parser.add_argument(
+        "--pos-dim",
+        type=int,
+        default=48,
+        help="Positional-encoding embedding dimension",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=str,
+        default="192,192,256,256,256,192",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+        "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=str,
+        default="31,31,15,15,15,31",
+        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+        "a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal",
+        type=str2bool,
+        default=False,
+        help="If True, use causal version of model.",
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        type=str,
+        default="16,32,64,-1",
+        help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+        " Must be just -1 if --causal=False",
+    )
+
+    parser.add_argument(
+        "--left-context-frames",
+        type=str,
+        default="64,128,256,-1",
+        help="Maximum left-contexts for causal training, measured in frames which will "
+        "be converted to a number of chunks. If splitting into chunks, "
+        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+    )
+
+    parser.add_argument(
+        "--use-transducer",
+        type=str2bool,
+        default=True,
+        help="If True, use Transducer head.",
+    )
+
+    parser.add_argument(
+        "--use-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use CTC head.",
+    )
+
+
+def get_parser():
+    """Build the command-line parser for training.
+
+    Training, optimization and checkpointing options are registered here;
+    model-architecture options are added by `add_model_arguments`.
+
+    Returns:
+      The configured `argparse.ArgumentParser`.
+    """
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.045, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=7500,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-hours",
+        type=float,
+        default=30000,
+        help="""Number of hours that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--ref-duration",
+        type=float,
+        default=600,
+        help="Reference batch duration for purposes of adjusting batch counts for setting various "
+        "schedules inside the model",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    # The help strings below are built from adjacent literals; each piece
+    # ends with a space so the words do not run together in --help output.
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context) "
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version "
+        "loss(joiner is just addition), this simple loss also uses for "
+        "training (as a regularization item). We will scale the simple loss "
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    # argparse %-formats help strings, so a literal percent sign must be
+    # written as "%%".
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=4000,
+        help="""Save checkpoint after processing this number of batches
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train %% save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--train-with-punctuation",
+        type=str2bool,
+        default=False,
+        help="If True, the training text will include casing and punctuation.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if `batch_idx % log_interval` is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+
+ - env_info: The value returned by `get_env_info()`.
+
+ Note: `encoder_dim` is not set here; it is a command-line option
+ (--encoder-dim). No `num_decoder_layers` entry is defined in this dict.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def _to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+ # encoder_embed converts the input of shape (N, T, num_features)
+ # to the shape (N, (T - 7) // 2, encoder_dims).
+ # That is, it does two things simultaneously:
+ # (1) subsampling: T -> (T - 7) // 2
+ # (2) embedding: num_features -> encoder_dims
+ # In the normal configuration, we will downsample once more at the end
+ # by a factor of 2, and most of the encoder stacks will run at a lower
+ # sampling rate.
+ encoder_embed = Conv2dSubsampling(
+ in_channels=params.feature_dim,
+ out_channels=_to_int_tuple(params.encoder_dim)[0],
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ )
+ return encoder_embed
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ encoder = Zipformer2(
+ output_downsampling_factor=2,
+ downsampling_factor=_to_int_tuple(params.downsampling_factor),
+ num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+ encoder_dim=_to_int_tuple(params.encoder_dim),
+ encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+ query_head_dim=_to_int_tuple(params.query_head_dim),
+ pos_head_dim=_to_int_tuple(params.pos_head_dim),
+ value_head_dim=_to_int_tuple(params.value_head_dim),
+ pos_dim=params.pos_dim,
+ num_heads=_to_int_tuple(params.num_heads),
+ feedforward_dim=_to_int_tuple(params.feedforward_dim),
+ cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ causal=params.causal,
+ chunk_size=_to_int_tuple(params.chunk_size),
+ left_context_frames=_to_int_tuple(params.left_context_frames),
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
+ f"but got params.use_transducer={params.use_transducer}, "
+ f"params.use_ctc={params.use_ctc}"
+ )
+
+ encoder_embed = get_encoder_embed(params)
+ encoder = get_encoder_model(params)
+
+ if params.use_transducer:
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+ else:
+ decoder = None
+ joiner = None
+
+ model = AsrModel(
+ encoder_embed=encoder_embed,
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+ decoder_dim=params.decoder_dim,
+ vocab_size=params.vocab_size,
+ use_transducer=params.use_transducer,
+ use_ctc=params.use_ctc,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.exp_dir/epoch-{params.start_epoch-1}.pt`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ `best_valid_loss`, and `batch_idx_train` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an AsrModel (possibly wrapped in DDP) in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ sp:
+ The BPE model, used to encode the transcripts into token IDs.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+
+ y = sp.encode(texts, out_type=int)
+ y = k2.RaggedTensor(y)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss, ctc_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ loss = 0.0
+
+ if params.use_transducer:
+ s = params.simple_loss_scale
+ # take down the scale on the simple loss from 1.0 at the start
+ # to params.simple_loss scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ if params.use_ctc:
+ loss += params.ctc_loss_scale * ctc_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ if params.use_transducer:
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+ if params.use_ctc:
+ info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step_batch() and step_epoch() every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ # Use the number of hours of speech to adjust the learning rate
+ scheduler.step_epoch(
+ params.batch_idx_train * params.max_duration * params.world_size / 3600
+ )
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ if not params.use_transducer:
+ params.ctc_loss_scale = 1.0
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def normalize_text(c: Cut):
+ text = remove_punc_to_upper(c.supervisions[0].text)
+ c.supervisions[0].text = text
+ return c
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 2 seconds and 30 seconds
+ #
+ # Caution: There is a reason to select 30.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 2.0 or c.duration > 30.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. "
+ f"Number of tokens: {len(tokens)}"
+ )
+ return False
+
+ return True
+
+ libriheavy = LibriHeavyAsrDataModule(args)
+
+ train_cuts = libriheavy.train_small_cuts()
+ if params.subset == "M" or params.subset == "L":
+ train_cuts += libriheavy.train_medium_cuts()
+ if params.subset == "L":
+ train_cuts += libriheavy.train_large_cuts()
+
+ if not params.train_with_punctuation:
+ train_cuts = train_cuts.map(normalize_text)
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = libriheavy.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = libriheavy.dev_cuts()
+
+ if not params.train_with_punctuation:
+ valid_cuts = valid_cuts.map(normalize_text)
+
+ valid_dl = libriheavy.valid_dataloaders(valid_cuts)
+
+ # if not params.print_diagnostics:
+ # scan_pessimistic_batches_for_oom(
+ # model=model,
+ # train_dl=train_dl,
+ # optimizer=optimizer,
+ # sp=sp,
+ # params=params,
+ # )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ LibriHeavyAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libriheavy/ASR/zipformer/zipformer.py b/egs/libriheavy/ASR/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/libriheavy/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py
similarity index 99%
rename from egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
rename to egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py
index 420dc1065..d614f0914 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py
@@ -1099,6 +1099,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
index 85dbd4661..953f95c45 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py
@@ -39,8 +39,8 @@ from pathlib import Path
import k2
import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
index ab046557f..1e59e0858 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py
@@ -61,7 +61,7 @@ import torch.nn as nn
from decoder import Decoder
from emformer import Emformer
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index 524366068..5195a4ef6 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -927,9 +927,9 @@ def main():
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
- contexts.append(line.strip())
+ contexts.append((sp.encode(line.strip()), 0.0))
context_graph = ContextGraph(params.context_score)
- context_graph.build(sp.encode(contexts))
+ context_graph.build(contexts)
else:
context_graph = None
else:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md
index d3691e647..0f3c63e75 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/README.md
@@ -4,7 +4,7 @@ See https://k2-fsa.github.io/icefall/recipes/Streaming-ASR/librispeech/zipformer
[./emformer.py](./emformer.py) and [./train.py](./train.py)
are basically the same as
-[./emformer2.py](./emformer2.py) and [./train2.py](./train2.py).
-The only purpose of [./emformer2.py](./emformer2.py) and [./train2.py](./train2.py)
+[./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py).
+The only purpose of [./emformer2.py](./emformer2.py) and [./do_not_use_it_directly.py](./do_not_use_it_directly.py)
is for exporting to [sherpa-ncnn](https://github.com/k2-fsa/sherpa-ncnn).
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
similarity index 99%
rename from egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py
rename to egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
index aa6c0668a..cd26db6f3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train2.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
@@ -1234,6 +1234,7 @@ def scan_pessimistic_batches_for_oom(
def main():
+ raise RuntimeError("Please don't use this file directly!")
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
index 07de57a86..a7d06a5dd 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
@@ -68,8 +68,8 @@ from pathlib import Path
import k2
import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
index 9a6b31268..8f2178b1d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
@@ -66,8 +66,8 @@ from pathlib import Path
import k2
import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py
new file mode 120000
index 000000000..beeffaa03
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/do_not_use_it_directly.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7_streaming/do_not_use_it_directly.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py
index 9a6b31268..8f2178b1d 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py
@@ -66,8 +66,8 @@ from pathlib import Path
import k2
import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
from scaling_converter import convert_scaled_to_non_scaled
-from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py
deleted file mode 120000
index 3c3280b68..000000000
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train2.py
+++ /dev/null
@@ -1 +0,0 @@
-../pruned_transducer_stateless7_streaming/train2.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py
index 3531d657f..339e253e6 100755
--- a/egs/librispeech/ASR/zipformer/decode.py
+++ b/egs/librispeech/ASR/zipformer/decode.py
@@ -1001,9 +1001,9 @@ def main():
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
- contexts.append(line.strip())
+ contexts.append((sp.encode(line.strip()), 0.0))
context_graph = ContextGraph(params.context_score)
- context_graph.build(sp.encode(contexts))
+ context_graph.build(contexts)
else:
context_graph = None
else:
diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py
new file mode 100755
index 000000000..97c9008fc
--- /dev/null
+++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes spectrogram features of the LJSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated spectrogram features are saved in data/spectrogram.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+ CutSet,
+ LilcomChunkyWriter,
+ Spectrogram,
+ SpectrogramConfig,
+ load_manifest,
+)
+from lhotse.audio import RecordingSet
+from lhotse.supervision import SupervisionSet
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_spectrogram_ljspeech():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/spectrogram")
+ num_jobs = min(4, os.cpu_count())
+
+ sampling_rate = 22050
+ frame_length = 1024 / sampling_rate # (in seconds)
+ frame_shift = 256 / sampling_rate # (in seconds)
+ use_fft_mag = True
+
+ prefix = "ljspeech"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ recordings = load_manifest(
+ src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
+ )
+ supervisions = load_manifest(
+ src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
+ )
+
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ use_fft_mag=use_fft_mag,
+ )
+ extractor = Spectrogram(config)
+
+ with get_executor() as ex: # Initialize the executor only once.
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{cuts_filename} already exists - skipping.")
+ return
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=recordings, supervisions=supervisions
+ )
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ compute_spectrogram_ljspeech()
diff --git a/egs/ljspeech/TTS/local/display_manifest_statistics.py b/egs/ljspeech/TTS/local/display_manifest_statistics.py
new file mode 100755
index 000000000..93f0044f0
--- /dev/null
+++ b/egs/ljspeech/TTS/local/display_manifest_statistics.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+
+See the function `remove_short_and_long_utt()` in vits/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
+ path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz"
+ cuts = load_manifest_lazy(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+Cut statistics:
+ ╒═══════════════════════════╤══════════╕
+ │ Cuts count: │ 13100 │
+ ├───────────────────────────┼──────────┤
+ │ Total duration (hh:mm:ss) │ 23:55:18 │
+ ├───────────────────────────┼──────────┤
+ │ mean │ 6.6 │
+ ├───────────────────────────┼──────────┤
+ │ std │ 2.2 │
+ ├───────────────────────────┼──────────┤
+ │ min │ 1.1 │
+ ├───────────────────────────┼──────────┤
+ │ 25% │ 5.0 │
+ ├───────────────────────────┼──────────┤
+ │ 50% │ 6.8 │
+ ├───────────────────────────┼──────────┤
+ │ 75% │ 8.4 │
+ ├───────────────────────────┼──────────┤
+ │ 99% │ 10.0 │
+ ├───────────────────────────┼──────────┤
+ │ 99.5% │ 10.1 │
+ ├───────────────────────────┼──────────┤
+ │ 99.9% │ 10.1 │
+ ├───────────────────────────┼──────────┤
+ │ max │ 10.1 │
+ ├───────────────────────────┼──────────┤
+ │ Recordings available: │ 13100 │
+ ├───────────────────────────┼──────────┤
+ │ Features available: │ 13100 │
+ ├───────────────────────────┼──────────┤
+ │ Supervisions available: │ 13100 │
+ ╘═══════════════════════════╧══════════╛
+"""
diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py
new file mode 100755
index 000000000..df976804a
--- /dev/null
+++ b/egs/ljspeech/TTS/local/prepare_token_file.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the tokens in the given manifest and generates the file that maps tokens to IDs.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict
+
+from lhotse import load_manifest
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--manifest-file",
+ type=Path,
+ default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
+ help="Path to the manifest file",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=Path,
+ default=Path("data/tokens.txt"),
+ help="Path to the tokens",
+ )
+
+ return parser.parse_args()
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_token2id(manifest_file: Path) -> Dict[str, int]:
+ """Return a dict that maps token to IDs."""
+ extra_tokens = [
+ "<blk>", # 0 for blank
+ "<sos/eos>", # 1 for sos and eos symbols.
+ "<unk>", # 2 for OOV
+ ]
+ all_tokens = set()
+
+ cut_set = load_manifest(manifest_file)
+
+ for cut in cut_set:
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ for t in cut.tokens:
+ all_tokens.add(t)
+
+ all_tokens = extra_tokens + list(all_tokens)
+
+ token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
+ return token2id
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ manifest_file = Path(args.manifest_file)
+ out_file = Path(args.tokens)
+
+ token2id = get_token2id(manifest_file)
+ write_mapping(out_file, token2id)
diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
new file mode 100755
index 000000000..fcd0137a0
--- /dev/null
+++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
+"""
+
+import logging
+from pathlib import Path
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import CutSet, load_manifest
+
+
+def prepare_tokens_ljspeech():
+ output_dir = Path("data/spectrogram")
+ prefix = "ljspeech"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+ g2p = g2p_en.G2p()
+
+ new_cuts = []
+ for cut in cut_set:
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ text = cut.supervisions[0].normalized_text
+ # Text normalization
+ text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+ # Convert to phonemes
+ cut.tokens = g2p(text)
+ new_cuts.append(cut)
+
+ new_cut_set = CutSet.from_cuts(new_cuts)
+ new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ prepare_tokens_ljspeech()
diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py
new file mode 100755
index 000000000..68159ae03
--- /dev/null
+++ b/egs/ljspeech/TTS/local/validate_manifest.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+
+We will add more checks later if needed.
+
+Usage example:
+
+ python3 ./local/validate_manifest.py \
+ ./data/spectrogram/ljspeech_cuts_all.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset.speech_synthesis import validate_for_tts
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ manifest = args.manifest
+ logging.info(f"Validating {manifest}")
+
+ assert manifest.is_file(), f"{manifest} does not exist"
+ cut_set = load_manifest_lazy(manifest)
+ assert isinstance(cut_set, CutSet), type(cut_set)
+
+ validate_for_tts(cut_set)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
new file mode 100755
index 000000000..8ee40896e
--- /dev/null
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -0,0 +1,117 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+nj=1
+stage=-1
+stop_stage=100
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # The directory $dl_dir/LJSpeech-1.1 will contain:
+ # - wavs, which contains the audio files
+ # - metadata.csv, which provides the transcript text for each audio clip
+
+ # If you have pre-downloaded it to /path/to/LJSpeech-1.1, you can create a symlink
+ #
+ # ln -sfv /path/to/LJSpeech-1.1 $dl_dir/LJSpeech-1.1
+ #
+ if [ ! -d $dl_dir/LJSpeech-1.1 ]; then
+ lhotse download ljspeech $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare LJSpeech manifest"
+ # We assume that you have downloaded the LJSpeech corpus
+ # to $dl_dir/LJSpeech-1.1
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.ljspeech.done ]; then
+ lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
+ touch data/manifests/.ljspeech.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute spectrogram for LJSpeech"
+ mkdir -p data/spectrogram
+ if [ ! -e data/spectrogram/.ljspeech.done ]; then
+ ./local/compute_spectrogram_ljspeech.py
+ touch data/spectrogram/.ljspeech.done
+ fi
+
+ if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then
+ log "Validating data/spectrogram for LJSpeech"
+ python3 ./local/validate_manifest.py \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz
+ touch data/spectrogram/.ljspeech-validated.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare phoneme tokens for LJSpeech"
+ if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
+ ./local/prepare_tokens_ljspeech.py
+ mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz
+ touch data/spectrogram/.ljspeech_with_token.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Split the LJSpeech cuts into train, valid and test sets"
+ if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
+ lhotse subset --last 600 \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
+ lhotse subset --first 100 \
+ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_valid.jsonl.gz
+ lhotse subset --last 500 \
+ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_test.jsonl.gz
+
+ rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
+
+ n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))
+ lhotse subset --first $n \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_train.jsonl.gz
+ touch data/spectrogram/.ljspeech_split.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Generate token file"
+ # We assume you have installed g2p_en and espnet_tts_frontend.
+ # If not, please install them with:
+ # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
+ # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+ if [ ! -e data/tokens.txt ]; then
+ ./local/prepare_token_file.py \
+ --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
+ --tokens data/tokens.txt
+ fi
+fi
+
+
diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh
new file mode 120000
index 000000000..e4665e7de
--- /dev/null
+++ b/egs/ljspeech/TTS/shared/parse_options.sh
@@ -0,0 +1 @@
+../../../librispeech/ASR/shared/parse_options.sh
\ No newline at end of file
diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md
new file mode 100644
index 000000000..1141326b9
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/README.md
@@ -0,0 +1,3 @@
+See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials.
+
+Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29.
diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py
new file mode 100644
index 000000000..c29a28479
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/duration_predictor.py
@@ -0,0 +1,194 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Stochastic duration predictor modules in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from flow import (
+ ConvFlow,
+ DilatedDepthSeparableConv,
+ ElementwiseAffineFlow,
+ FlipFlow,
+ LogFlow,
+)
+
+
+class StochasticDurationPredictor(torch.nn.Module):
+ """Stochastic duration predictor module.
+
+ This is a module of stochastic duration predictor described in `Conditional
+ Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
+
+ .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`: https://arxiv.org/abs/2106.06103
+
+ """
+
+ def __init__(
+ self,
+ channels: int = 192,
+ kernel_size: int = 3,
+ dropout_rate: float = 0.5,
+ flows: int = 4,
+ dds_conv_layers: int = 3,
+ global_channels: int = -1,
+ ):
+ """Initialize StochasticDurationPredictor module.
+
+ Args:
+ channels (int): Number of channels.
+ kernel_size (int): Kernel size.
+ dropout_rate (float): Dropout rate.
+ flows (int): Number of flows.
+ dds_conv_layers (int): Number of conv layers in DDS conv.
+ global_channels (int): Number of global conditioning channels.
+
+ """
+ super().__init__()
+
+ self.pre = torch.nn.Conv1d(channels, channels, 1)
+ self.dds = DilatedDepthSeparableConv(
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ dropout_rate=dropout_rate,
+ )
+ self.proj = torch.nn.Conv1d(channels, channels, 1)
+
+ self.log_flow = LogFlow()
+ self.flows = torch.nn.ModuleList()
+ self.flows += [ElementwiseAffineFlow(2)]
+ for i in range(flows):
+ self.flows += [
+ ConvFlow(
+ 2,
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ )
+ ]
+ self.flows += [FlipFlow()]
+
+ self.post_pre = torch.nn.Conv1d(1, channels, 1)
+ self.post_dds = DilatedDepthSeparableConv(
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ dropout_rate=dropout_rate,
+ )
+ self.post_proj = torch.nn.Conv1d(channels, channels, 1)
+ self.post_flows = torch.nn.ModuleList()
+ self.post_flows += [ElementwiseAffineFlow(2)]
+ for i in range(flows):
+ self.post_flows += [
+ ConvFlow(
+ 2,
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ )
+ ]
+ self.post_flows += [FlipFlow()]
+
+ if global_channels > 0:
+ self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ w: Optional[torch.Tensor] = None,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ noise_scale: float = 1.0,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T_text).
+ x_mask (Tensor): Mask tensor (B, 1, T_text).
+ w (Optional[Tensor]): Duration tensor (B, 1, T_text).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1)
+ inverse (bool): Whether to inverse the flow.
+ noise_scale (float): Noise scale value.
+
+ Returns:
+ Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
+ If inverse, log-duration tensor (B, 1, T_text).
+
+ """
+ x = x.detach() # stop gradient
+ x = self.pre(x)
+ if g is not None:
+ x = x + self.global_conv(g.detach()) # stop gradient
+ x = self.dds(x, x_mask)
+ x = self.proj(x) * x_mask
+
+ if not inverse:
+ assert w is not None, "w must be provided."
+ h_w = self.post_pre(w)
+ h_w = self.post_dds(h_w, x_mask)
+ h_w = self.post_proj(h_w) * x_mask
+ e_q = (
+ torch.randn(
+ w.size(0),
+ 2,
+ w.size(2),
+ ).to(device=x.device, dtype=x.dtype)
+ * x_mask
+ )
+ z_q = e_q
+ logdet_tot_q = 0.0
+ for flow in self.post_flows:
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
+ logdet_tot_q += logdet_q
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
+ u = torch.sigmoid(z_u) * x_mask
+ z0 = (w - u) * x_mask
+ logdet_tot_q += torch.sum(
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
+ )
+ logq = (
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
+ - logdet_tot_q
+ )
+
+ logdet_tot = 0
+ z0, logdet = self.log_flow(z0, x_mask)
+ logdet_tot += logdet
+ z = torch.cat([z0, z1], 1)
+ for flow in self.flows:
+ z, logdet = flow(z, x_mask, g=x, inverse=inverse)
+ logdet_tot = logdet_tot + logdet
+ nll = (
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
+ - logdet_tot
+ )
+ return nll + logq # (B,)
+ else:
+ flows = list(reversed(self.flows))
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
+ z = (
+ torch.randn(
+ x.size(0),
+ 2,
+ x.size(2),
+ ).to(device=x.device, dtype=x.dtype)
+ * noise_scale
+ )
+ for flow in flows:
+ z = flow(z, x_mask, g=x, inverse=inverse)
+ z0, z1 = z.split(1, 1)
+ logw = z0
+ return logw
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
new file mode 100755
index 000000000..154de4bf4
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -0,0 +1,261 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script exports a VITS model from PyTorch to ONNX.
+
+Export the model to ONNX:
+./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+It will generate two files inside vits/exp:
+ - vits-epoch-1000.onnx
+ - vits-epoch-1000.int8.onnx (quantized model)
+
+See ./test_onnx.py for how to use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Tuple
+
+import onnx
+import torch
+import torch.nn as nn
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from tokenizer import Tokenizer
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
+
+
+class OnnxModel(nn.Module):
+ """A wrapper for VITS generator."""
+
+ def __init__(self, model: nn.Module):
+ """
+ Args:
+ model:
+ A VITS generator. Its `inference` method is called by
+ `forward`; the tokens, noise scales and alpha given to
+ `forward` are passed through to it.
+ """
+ super().__init__()
+ self.model = model
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ tokens_lens: torch.Tensor,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ alpha: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Please see the help information of VITS.inference_batch
+
+ Args:
+ tokens:
+ Input text token indexes (1, T_text)
+ tokens_lens:
+ Number of tokens of shape (1,)
+ noise_scale (float):
+ Noise scale parameter for flow.
+ noise_scale_dur (float):
+ Noise scale parameter for duration predictor.
+ alpha (float):
+ Alpha parameter to control the speed of generated speech.
+
+ Returns:
+ Return the generated audio only:
+ - audio, generated waveform tensor, (B, T_wav)
+ """
+ audio, _, _ = self.model.inference(
+ text=tokens,
+ text_lengths=tokens_lens,
+ noise_scale=noise_scale,
+ noise_scale_dur=noise_scale_dur,
+ alpha=alpha,
+ )
+ return audio
+
+
+def export_model_onnx(
+ model: nn.Module,
+ model_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the given generator model to ONNX format.
+ The exported model has five inputs (see `input_names` below); the main one is:
+
+ - tokens, a tensor of shape (1, T_text); dtype is torch.int64
+
+ and it has one output:
+
+ - audio, a tensor of shape (1, T'); dtype is torch.float32
+
+ Args:
+ model:
+ The VITS generator.
+ model_filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+ tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
+ noise_scale = torch.tensor([1], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([1], dtype=torch.float32)
+ alpha = torch.tensor([1], dtype=torch.float32)
+
+ torch.onnx.export(
+ model,
+ (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
+ model_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
+ output_names=["audio"],
+ dynamic_axes={
+ "tokens": {0: "N", 1: "T"},
+ "tokens_lens": {0: "N"},
+ "audio": {0: "N", 1: "T"},
+ },
+ )
+
+ meta_data = {
+ "model_type": "VITS",
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "VITS generator",
+ }
+ logging.info(f"meta_data: {meta_data}")
+
+ add_meta_data(filename=model_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model = model.generator
+ model.to("cpu")
+ model.eval()
+
+ model = OnnxModel(model=model)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"generator parameters: {num_param}")
+
+ suffix = f"epoch-{params.epoch}"
+
+ opset_version = 13
+
+ logging.info("Exporting encoder")
+ model_filename = params.exp_dir / f"vits-{suffix}.onnx"
+ export_model_onnx(
+ model,
+ model_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported generator to {model_filename}")
+
+ # Generate int8 quantization models
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+ logging.info("Generate int8 quantization models")
+
+ model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=model_filename,
+ model_output=model_filename_int8,
+ weight_type=QuantType.QUInt8,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py
new file mode 100644
index 000000000..206bd5e3e
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/flow.py
@@ -0,0 +1,312 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Basic Flow modules used in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+
+from transform import piecewise_rational_quadratic_transform
+
+
+class FlipFlow(torch.nn.Module):
+ """Flip flow module: reverses the order of the input along dim 1 (channels)."""
+
+ def forward(
+ self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation. Extra ``*args``/``**kwargs`` are accepted but unused.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ inverse (bool): Whether to inverse the flow; the flip itself is applied in both directions.
+
+ Returns:
+ Tensor: Flipped tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse (always zeros).
+
+ """
+ x = torch.flip(x, [1])
+ if not inverse:
+ logdet = x.new_zeros(x.size(0))  # one zero per batch item, same dtype/device as x
+ return x, logdet
+ else:
+ return x
+
+
+class LogFlow(torch.nn.Module):
+ """Log flow module: masked element-wise log in the forward direction, masked exp in the inverse."""
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ inverse: bool = False,
+ eps: float = 1e-5,
+ **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation. Extra ``**kwargs`` are accepted but unused.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ inverse (bool): Whether to inverse the flow.
+ eps (float): Lower clamp applied to x before the log (forward direction only).
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ if not inverse:
+ y = torch.log(torch.clamp_min(x, eps)) * x_mask  # clamp to eps so the log stays finite
+ logdet = torch.sum(-y, [1, 2])  # sum of -log(x) over channels and time
+ return y, logdet
+ else:
+ x = torch.exp(x) * x_mask
+ return x
+
+
+class ElementwiseAffineFlow(torch.nn.Module):
+ """Elementwise affine flow module: y = m + exp(logs) * x with one learned (m, logs) pair per channel."""
+
+ def __init__(self, channels: int):
+ """Initialize ElementwiseAffineFlow module.
+
+ Args:
+ channels (int): Number of channels.
+
+ """
+ super().__init__()
+ self.channels = channels
+ self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))  # shift, zero-initialized
+ self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))  # log-scale, zero-initialized
+
+ def forward(
+ self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation. Extra ``**kwargs`` are accepted but unused.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ x_mask (Tensor): Mask tensor; presumably (B, 1, T) as in the sibling flows.
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ if not inverse:
+ y = self.m + torch.exp(self.logs) * x
+ y = y * x_mask
+ logdet = torch.sum(self.logs * x_mask, [1, 2])  # log-scales summed over channels and unmasked positions
+ return y, logdet
+ else:
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
+ return x
+
+
+class Transpose(torch.nn.Module):
+ """Transpose module for torch.nn.Sequential(): swaps two fixed dimensions of its input."""
+
+ def __init__(self, dim1: int, dim2: int):
+ """Initialize Transpose module with the two dimensions (dim1, dim2) to swap."""
+ super().__init__()
+ self.dim1 = dim1
+ self.dim2 = dim2
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Return x with self.dim1 and self.dim2 transposed."""
+ return x.transpose(self.dim1, self.dim2)
+
+
+class DilatedDepthSeparableConv(torch.nn.Module):
+ """Dilated depth-separable conv module: a residual stack of depthwise (dilated) + pointwise conv layers."""
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ layers: int,
+ dropout_rate: float = 0.0,
+ eps: float = 1e-5,
+ ):
+ """Initialize DilatedDepthSeparableConv module.
+
+ Args:
+ channels (int): Number of channels.
+ kernel_size (int): Kernel size.
+ layers (int): Number of layers.
+ dropout_rate (float): Dropout rate.
+ eps (float): Epsilon for layer norm.
+
+ """
+ super().__init__()
+
+ self.convs = torch.nn.ModuleList()
+ for i in range(layers):
+ dilation = kernel_size**i  # dilation grows geometrically with the layer index
+ padding = (kernel_size * dilation - dilation) // 2  # dilation * (kernel_size - 1) / 2: keeps T unchanged for odd kernel_size
+ self.convs += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ groups=channels,
+ dilation=dilation,
+ padding=padding,
+ ),
+ Transpose(1, 2),
+ torch.nn.LayerNorm(
+ channels,
+ eps=eps,
+ elementwise_affine=True,
+ ),
+ Transpose(1, 2),
+ torch.nn.GELU(),
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ 1,
+ ),
+ Transpose(1, 2),
+ torch.nn.LayerNorm(
+ channels,
+ eps=eps,
+ elementwise_affine=True,
+ ),
+ Transpose(1, 2),
+ torch.nn.GELU(),
+ torch.nn.Dropout(dropout_rate),
+ )
+ ]
+
+ def forward(
+ self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+
+ """
+ if g is not None:
+ x = x + g  # conditioning is added directly, so g must broadcast against x
+ for f in self.convs:
+ y = f(x * x_mask)  # each layer sees the masked input
+ x = x + y  # residual connection
+ return x * x_mask
+
+
+class ConvFlow(torch.nn.Module):
+ """Convolutional flow module: the first half of the channels parameterizes a piecewise rational-quadratic transform of the second half."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_channels: int,
+ kernel_size: int,
+ layers: int,
+ bins: int = 10,
+ tail_bound: float = 5.0,
+ ):
+ """Initialize ConvFlow module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ hidden_channels (int): Number of hidden channels.
+ kernel_size (int): Kernel size.
+ layers (int): Number of layers.
+ bins (int): Number of bins.
+ tail_bound (float): Tail bound value.
+
+ """
+ super().__init__()
+ self.half_channels = in_channels // 2
+ self.hidden_channels = hidden_channels
+ self.bins = bins
+ self.tail_bound = tail_bound
+
+ self.input_conv = torch.nn.Conv1d(
+ self.half_channels,
+ hidden_channels,
+ 1,
+ )
+ self.dds_conv = DilatedDepthSeparableConv(
+ hidden_channels,
+ kernel_size,
+ layers,
+ dropout_rate=0.0,
+ )
+ self.proj = torch.nn.Conv1d(
+ hidden_channels,
+ self.half_channels * (bins * 3 - 1),
+ 1,
+ )
+ self.proj.weight.data.zero_()  # zero-initialize the projection weights ...
+ self.proj.bias.data.zero_()  # ... and bias, so all spline parameters start at zero
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ xa, xb = x.split(x.size(1) // 2, 1)  # xa passes through unchanged; xb is transformed
+ h = self.input_conv(xa)
+ h = self.dds_conv(h, x_mask, g=g)
+ h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T)
+
+ b, c, t = xa.shape
+ # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)
+
+ # TODO(kan-bayashi): Understand this calculation
+ denom = math.sqrt(self.hidden_channels)
+ unnorm_widths = h[..., : self.bins] / denom
+ unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
+ unnorm_derivatives = h[..., 2 * self.bins :]  # the remaining bins - 1 entries
+ xb, logdet_abs = piecewise_rational_quadratic_transform(
+ xb,
+ unnorm_widths,
+ unnorm_heights,
+ unnorm_derivatives,
+ inverse=inverse,
+ tails="linear",
+ tail_bound=self.tail_bound,
+ )
+ x = torch.cat([xa, xb], 1) * x_mask
+ logdet = torch.sum(logdet_abs * x_mask, [1, 2])  # computed in both directions, returned only when not inverse
+ if not inverse:
+ return x, logdet
+ else:
+ return x
diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py
new file mode 100644
index 000000000..efb0e254c
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/generator.py
@@ -0,0 +1,531 @@
+# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Generator module in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+
+import math
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from icefall.utils import make_pad_mask
+
+from duration_predictor import StochasticDurationPredictor
+from hifigan import HiFiGANGenerator
+from posterior_encoder import PosteriorEncoder
+from residual_coupling import ResidualAffineCouplingBlock
+from text_encoder import TextEncoder
+from utils import get_random_segments
+
+
+class VITSGenerator(torch.nn.Module):
+ """Generator module in VITS, `Conditional Variational Autoencoder
+ with Adversarial Learning for End-to-End Text-to-Speech`.
+ """
+
+ def __init__(
+ self,
+ vocabs: int,
+ aux_channels: int = 513,
+ hidden_channels: int = 192,
+ spks: Optional[int] = None,
+ langs: Optional[int] = None,
+ spk_embed_dim: Optional[int] = None,
+ global_channels: int = -1,
+ segment_size: int = 32,
+ text_encoder_attention_heads: int = 2,
+ text_encoder_ffn_expand: int = 4,
+ text_encoder_cnn_module_kernel: int = 5,
+ text_encoder_blocks: int = 6,
+ text_encoder_dropout_rate: float = 0.1,
+ decoder_kernel_size: int = 7,
+ decoder_channels: int = 512,
+ decoder_upsample_scales: List[int] = [8, 8, 2, 2],
+ decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
+ decoder_resblock_kernel_sizes: List[int] = [3, 7, 11],
+ decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ use_weight_norm_in_decoder: bool = True,
+ posterior_encoder_kernel_size: int = 5,
+ posterior_encoder_layers: int = 16,
+ posterior_encoder_stacks: int = 1,
+ posterior_encoder_base_dilation: int = 1,
+ posterior_encoder_dropout_rate: float = 0.0,
+ use_weight_norm_in_posterior_encoder: bool = True,
+ flow_flows: int = 4,
+ flow_kernel_size: int = 5,
+ flow_base_dilation: int = 1,
+ flow_layers: int = 4,
+ flow_dropout_rate: float = 0.0,
+ use_weight_norm_in_flow: bool = True,
+ use_only_mean_in_flow: bool = True,
+ stochastic_duration_predictor_kernel_size: int = 3,
+ stochastic_duration_predictor_dropout_rate: float = 0.5,
+ stochastic_duration_predictor_flows: int = 4,
+ stochastic_duration_predictor_dds_conv_layers: int = 3,
+ ):
+ """Initialize VITS generator module.
+
+ Args:
+ vocabs (int): Input vocabulary size.
+ aux_channels (int): Number of acoustic feature channels.
+ hidden_channels (int): Number of hidden channels.
+ spks (Optional[int]): Number of speakers. If set to > 1, assume that the
+ sids will be provided as the input and use sid embedding layer.
+ langs (Optional[int]): Number of languages. If set to > 1, assume that the
+ lids will be provided as the input and use lid embedding layer.
+ spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
+ assume that spembs will be provided as the input.
+ global_channels (int): Number of global conditioning channels. Must be > 0 when spks, langs or spk_embed_dim enables conditioning.
+ segment_size (int): Segment size for decoder.
+ text_encoder_attention_heads (int): Number of heads in conformer block
+ of text encoder.
+ text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
+ of text encoder.
+ text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
+ text_encoder_blocks (int): Number of conformer blocks in text encoder.
+ text_encoder_dropout_rate (float): Dropout rate in conformer block of
+ text encoder.
+ decoder_kernel_size (int): Decoder kernel size.
+ decoder_channels (int): Number of decoder initial channels.
+ decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
+ decoder_upsample_kernel_sizes (List[int]): List of kernel size for
+ upsampling layers in decoder.
+ decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
+ in decoder.
+ decoder_resblock_dilations (List[List[int]]): List of list of dilations for
+ resblocks in decoder.
+ use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
+ decoder.
+ posterior_encoder_kernel_size (int): Posterior encoder kernel size.
+ posterior_encoder_layers (int): Number of layers of posterior encoder.
+ posterior_encoder_stacks (int): Number of stacks of posterior encoder.
+ posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
+ posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
+ use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
+ normalization in posterior encoder.
+ flow_flows (int): Number of flows in flow.
+ flow_kernel_size (int): Kernel size in flow.
+ flow_base_dilation (int): Base dilation in flow.
+ flow_layers (int): Number of layers in flow.
+ flow_dropout_rate (float): Dropout rate in flow
+ use_weight_norm_in_flow (bool): Whether to apply weight normalization in
+ flow.
+ use_only_mean_in_flow (bool): Whether to use only mean in flow.
+ stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
+ duration predictor.
+ stochastic_duration_predictor_dropout_rate (float): Dropout rate in
+ stochastic duration predictor.
+ stochastic_duration_predictor_flows (int): Number of flows in stochastic
+ duration predictor.
+ stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
+ layers in stochastic duration predictor.
+
+ """
+ super().__init__()
+ self.segment_size = segment_size
+ self.text_encoder = TextEncoder(
+ vocabs=vocabs,
+ d_model=hidden_channels,
+ num_heads=text_encoder_attention_heads,
+ dim_feedforward=hidden_channels * text_encoder_ffn_expand,
+ cnn_module_kernel=text_encoder_cnn_module_kernel,
+ num_layers=text_encoder_blocks,
+ dropout=text_encoder_dropout_rate,
+ )
+ self.decoder = HiFiGANGenerator(
+ in_channels=hidden_channels,
+ out_channels=1,
+ channels=decoder_channels,
+ global_channels=global_channels,
+ kernel_size=decoder_kernel_size,
+ upsample_scales=decoder_upsample_scales,
+ upsample_kernel_sizes=decoder_upsample_kernel_sizes,
+ resblock_kernel_sizes=decoder_resblock_kernel_sizes,
+ resblock_dilations=decoder_resblock_dilations,
+ use_weight_norm=use_weight_norm_in_decoder,
+ )
+ self.posterior_encoder = PosteriorEncoder(
+ in_channels=aux_channels,
+ out_channels=hidden_channels,
+ hidden_channels=hidden_channels,
+ kernel_size=posterior_encoder_kernel_size,
+ layers=posterior_encoder_layers,
+ stacks=posterior_encoder_stacks,
+ base_dilation=posterior_encoder_base_dilation,
+ global_channels=global_channels,
+ dropout_rate=posterior_encoder_dropout_rate,
+ use_weight_norm=use_weight_norm_in_posterior_encoder,
+ )
+ self.flow = ResidualAffineCouplingBlock(
+ in_channels=hidden_channels,
+ hidden_channels=hidden_channels,
+ flows=flow_flows,
+ kernel_size=flow_kernel_size,
+ base_dilation=flow_base_dilation,
+ layers=flow_layers,
+ global_channels=global_channels,
+ dropout_rate=flow_dropout_rate,
+ use_weight_norm=use_weight_norm_in_flow,
+ use_only_mean=use_only_mean_in_flow,
+ )
+ # TODO(kan-bayashi): Add deterministic version as an option
+ self.duration_predictor = StochasticDurationPredictor(
+ channels=hidden_channels,
+ kernel_size=stochastic_duration_predictor_kernel_size,
+ dropout_rate=stochastic_duration_predictor_dropout_rate,
+ flows=stochastic_duration_predictor_flows,
+ dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
+ global_channels=global_channels,
+ )
+
+ self.upsample_factor = int(np.prod(decoder_upsample_scales))  # product of all decoder upsampling scales
+ self.spks = None  # stays None unless spks > 1; forward()/inference() test it to decide whether to embed sids
+ if spks is not None and spks > 1:
+ assert global_channels > 0
+ self.spks = spks
+ self.global_emb = torch.nn.Embedding(spks, global_channels)
+ self.spk_embed_dim = None  # stays None unless spk_embed_dim > 0
+ if spk_embed_dim is not None and spk_embed_dim > 0:
+ assert global_channels > 0
+ self.spk_embed_dim = spk_embed_dim
+ self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)
+ self.langs = None  # stays None unless langs > 1
+ if langs is not None and langs > 1:
+ assert global_channels > 0
+ self.langs = langs
+ self.lang_emb = torch.nn.Embedding(langs, global_channels)
+
+ # delayed import
+ from monotonic_align import maximum_path
+
+ self.maximum_path = maximum_path
+
+ def forward(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: torch.Tensor,
+ feats_lengths: torch.Tensor,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ ) -> Tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ Tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ ],
+ ]:
+ """Calculate forward propagation.
+
+ Args:
+ text (Tensor): Text index tensor (B, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, aux_channels, T_feats).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+ Returns:
+ Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
+ Tensor: Duration negative log-likelihood (NLL) tensor (B,).
+ Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
+ Tensor: Segments start index tensor (B,).
+ Tensor: Text mask tensor (B, 1, T_text).
+ Tensor: Feature mask tensor (B, 1, T_feats).
+ tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ - Tensor: Posterior encoder hidden representation (B, H, T_feats).
+ - Tensor: Flow hidden representation (B, H, T_feats).
+ - Tensor: Expanded text encoder projected mean (B, H, T_feats).
+ - Tensor: Expanded text encoder projected scale (B, H, T_feats).
+ - Tensor: Posterior encoder projected mean (B, H, T_feats).
+ - Tensor: Posterior encoder projected scale (B, H, T_feats).
+
+ """
+ # forward text encoder
+ x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
+
+ # calculate global conditioning
+ g = None
+ if self.spks is not None:
+ # speaker one-hot vector embedding: (B, global_channels, 1)
+ g = self.global_emb(sids.view(-1)).unsqueeze(-1)
+ if self.spk_embed_dim is not None:
+ # pretrained speaker embedding, e.g., X-vector (B, global_channels, 1)
+ g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+ if self.langs is not None:
+ # language one-hot vector embedding: (B, global_channels, 1)
+ g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+
+ # forward posterior encoder
+ z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
+
+ # forward flow
+ z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
+
+ # monotonic alignment search
+ with torch.no_grad():
+ # negative cross-entropy
+ s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
+ # (B, 1, T_text)
+ neg_x_ent_1 = torch.sum(
+ -0.5 * math.log(2 * math.pi) - logs_p,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_2 = torch.matmul(
+ -0.5 * (z_p**2).transpose(1, 2),
+ s_p_sq_r,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_3 = torch.matmul(
+ z_p.transpose(1, 2),
+ (m_p * s_p_sq_r),
+ )
+ # (B, 1, T_text)
+ neg_x_ent_4 = torch.sum(
+ -0.5 * (m_p**2) * s_p_sq_r,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, T_text)
+ neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
+ # (B, 1, T_feats, T_text)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ # monotonic attention weight: (B, 1, T_feats, T_text)
+ attn = (
+ self.maximum_path(
+ neg_x_ent,
+ attn_mask.squeeze(1),
+ )
+ .unsqueeze(1)
+ .detach()
+ )
+
+ # forward duration predictor
+ w = attn.sum(2) # (B, 1, T_text)
+ dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
+ dur_nll = dur_nll / torch.sum(x_mask)  # normalize by the sum of the text mask over the whole batch
+
+ # expand the length to match with the feature sequence
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
+
+ # get random segments
+ z_segments, z_start_idxs = get_random_segments(
+ z,
+ feats_lengths,
+ self.segment_size,
+ )
+
+ # forward decoder with random segments
+ wav = self.decoder(z_segments, g=g)
+
+ return (
+ wav,
+ dur_nll,
+ attn,
+ z_start_idxs,
+ x_mask,
+ y_mask,
+ (z, z_p, m_p, logs_p, m_q, logs_q),
+ )
+
+    def inference(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        feats: Optional[torch.Tensor] = None,
+        feats_lengths: Optional[torch.Tensor] = None,
+        sids: Optional[torch.Tensor] = None,
+        spembs: Optional[torch.Tensor] = None,
+        lids: Optional[torch.Tensor] = None,
+        dur: Optional[torch.Tensor] = None,
+        noise_scale: float = 0.667,
+        noise_scale_dur: float = 0.8,
+        alpha: float = 1.0,
+        max_len: Optional[int] = None,
+        use_teacher_forcing: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Run inference.
+
+        Args:
+            text (Tensor): Input text index tensor (B, T_text,).
+            text_lengths (Tensor): Text length tensor (B,).
+            feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
+            feats_lengths (Tensor): Feature length tensor (B,).
+            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim); a single un-batched (spk_embed_dim,) vector is also accepted.
+            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+            dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
+                skip the prediction of durations (i.e., teacher forcing).
+            noise_scale (float): Noise scale parameter for flow.
+            noise_scale_dur (float): Noise scale parameter for duration predictor.
+            alpha (float): Alpha parameter to control the speed of generated speech.
+            max_len (Optional[int]): Maximum length of acoustic feature sequence.
+            use_teacher_forcing (bool): Whether to use teacher forcing.
+
+        Returns:
+            Tensor: Generated waveform tensor (B, T_wav).
+            Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
+            Tensor: Duration tensor (B, T_text).
+
+        """
+        # encoder
+        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
+        x_mask = x_mask.to(x.dtype)
+        g = None
+        if self.spks is not None:
+            # (B, global_channels, 1)
+            g = self.global_emb(sids.view(-1)).unsqueeze(-1)
+        if self.spk_embed_dim is not None:
+            # reshape to (B, spk_embed_dim) so normalization runs over the embedding axis, as in forward() -> (B, global_channels, 1)
+            g_ = self.spemb_proj(F.normalize(spembs.reshape(-1, self.spk_embed_dim))).unsqueeze(-1)
+            if g is None:
+                g = g_
+            else:
+                g = g + g_
+        if self.langs is not None:
+            # (B, global_channels, 1)
+            g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
+            if g is None:
+                g = g_
+            else:
+                g = g + g_
+
+        if use_teacher_forcing:
+            # forward posterior encoder
+            z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
+
+            # forward flow
+            z_p = self.flow(z, y_mask, g=g)  # (B, H, T_feats)
+
+            # monotonic alignment search
+            s_p_sq_r = torch.exp(-2 * logs_p)  # (B, H, T_text)
+            # (B, 1, T_text)
+            neg_x_ent_1 = torch.sum(
+                -0.5 * math.log(2 * math.pi) - logs_p,
+                [1],
+                keepdim=True,
+            )
+            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+            neg_x_ent_2 = torch.matmul(
+                -0.5 * (z_p**2).transpose(1, 2),
+                s_p_sq_r,
+            )
+            # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+            neg_x_ent_3 = torch.matmul(
+                z_p.transpose(1, 2),
+                (m_p * s_p_sq_r),
+            )
+            # (B, 1, T_text)
+            neg_x_ent_4 = torch.sum(
+                -0.5 * (m_p**2) * s_p_sq_r,
+                [1],
+                keepdim=True,
+            )
+            # (B, T_feats, T_text)
+            neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
+            # (B, 1, T_feats, T_text)
+            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+            # monotonic attention weight: (B, 1, T_feats, T_text)
+            attn = self.maximum_path(
+                neg_x_ent,
+                attn_mask.squeeze(1),
+            ).unsqueeze(1)
+            dur = attn.sum(2)  # (B, 1, T_text)
+
+            # forward decoder with the full (masked) posterior latent sequence
+            wav = self.decoder(z * y_mask, g=g)
+        else:
+            # duration
+            if dur is None:
+                logw = self.duration_predictor(
+                    x,
+                    x_mask,
+                    g=g,
+                    inverse=True,
+                    noise_scale=noise_scale_dur,
+                )
+                w = torch.exp(logw) * x_mask * alpha
+                dur = torch.ceil(w)
+            y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
+            y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
+            y_mask = y_mask.to(x.dtype)
+            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+            attn = self._generate_path(dur, attn_mask)
+
+            # expand the length to match with the feature sequence
+            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+            m_p = torch.matmul(
+                attn.squeeze(1),
+                m_p.transpose(1, 2),
+            ).transpose(1, 2)
+            # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+            logs_p = torch.matmul(
+                attn.squeeze(1),
+                logs_p.transpose(1, 2),
+            ).transpose(1, 2)
+
+            # decoder
+            z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+            z = self.flow(z_p, y_mask, g=g, inverse=True)
+            wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
+
+        return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
+
+ def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """Generate path a.k.a. monotonic attention: expand per-token durations into a 0/1 alignment matrix.
+
+ Args:
+ dur (Tensor): Duration tensor (B, 1, T_text).
+ mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
+
+ Returns:
+ Tensor: Path tensor (B, 1, T_feats, T_text).
+
+ """
+ b, _, t_y, t_x = mask.shape
+ cum_dur = torch.cumsum(dur, -1)  # cumulative duration: end frame (exclusive) of each text token
+ cum_dur_flat = cum_dur.view(b * t_x)
+ path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
+ path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)  # (B * T_text, T_feats): frame index < token end
+ # path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
+ path = path.view(b, t_x, t_y).to(dtype=torch.float)
+ # path will be like (t_x = 3, t_y = 5):
+ # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
+ # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
+ # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
+ path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]  # subtract the previous token's row so each frame maps to one token
+ # path = path.to(dtype=mask.dtype)
+ return path.unsqueeze(1).transpose(2, 3) * mask
diff --git a/egs/ljspeech/TTS/vits/hifigan.py b/egs/ljspeech/TTS/vits/hifigan.py
new file mode 100644
index 000000000..589ac30f6
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/hifigan.py
@@ -0,0 +1,933 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""HiFi-GAN Modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+import copy
+import logging
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+class HiFiGANGenerator(torch.nn.Module):
+ """HiFiGAN generator module (input conv, upsampling stages, output conv)."""
+
+ def __init__(
+ self,
+ in_channels: int = 80,
+ out_channels: int = 1,
+ channels: int = 512,
+ global_channels: int = -1,
+ kernel_size: int = 7,
+ upsample_scales: List[int] = [8, 8, 2, 2],
+ upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
+ resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ use_additional_convs: bool = True,
+ bias: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ use_weight_norm: bool = True,
+ ):
+ """Initialize HiFiGANGenerator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ channels (int): Number of hidden representation channels.
+ global_channels (int): Number of global conditioning channels (<= 0 creates no global conv).
+ kernel_size (int): Kernel size of initial and final conv layer (must be odd).
+ upsample_scales (List[int]): List of upsampling scales.
+ upsample_kernel_sizes (List[int]): Upsample kernel sizes; each must be 2 * the matching scale.
+ resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks.
+ resblock_dilations (List[List[int]]): List of list of dilations for residual
+ blocks.
+ use_additional_convs (bool): Whether to use additional conv layers in
+ residual blocks.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+ use_weight_norm (bool): Whether to use weight norm. If set to true, it will
+ be applied to all of the conv layers.
+
+ """
+ super().__init__()
+
+ # check hyperparameters are valid
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
+ assert len(upsample_scales) == len(upsample_kernel_sizes)
+ assert len(resblock_dilations) == len(resblock_kernel_sizes)
+
+ # define modules
+ self.upsample_factor = int(np.prod(upsample_scales) * out_channels)
+ self.num_upsamples = len(upsample_kernel_sizes)
+ self.num_blocks = len(resblock_kernel_sizes)
+ self.input_conv = torch.nn.Conv1d(
+ in_channels,
+ channels,
+ kernel_size,
+ 1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.upsamples = torch.nn.ModuleList()
+ self.blocks = torch.nn.ModuleList()
+ for i in range(len(upsample_kernel_sizes)):
+ assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
+ self.upsamples += [
+ torch.nn.Sequential(
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ torch.nn.ConvTranspose1d(
+ channels // (2**i),
+ channels // (2 ** (i + 1)),
+ upsample_kernel_sizes[i],
+ upsample_scales[i],
+ padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
+ output_padding=upsample_scales[i] % 2,
+ ),
+ )
+ ]
+ for j in range(len(resblock_kernel_sizes)):
+ self.blocks += [
+ ResidualBlock(
+ kernel_size=resblock_kernel_sizes[j],
+ channels=channels // (2 ** (i + 1)),
+ dilations=resblock_dilations[j],
+ bias=bias,
+ use_additional_convs=use_additional_convs,
+ nonlinear_activation=nonlinear_activation,
+ nonlinear_activation_params=nonlinear_activation_params,
+ )
+ ]
+ self.output_conv = torch.nn.Sequential(
+ # NOTE(kan-bayashi): follow official implementation but why
+ # using different slope parameter here? (0.1 vs. 0.01)
+ torch.nn.LeakyReLU(),
+ torch.nn.Conv1d(
+ channels // (2 ** (i + 1)),
+ out_channels,
+ kernel_size,
+ 1,
+ padding=(kernel_size - 1) // 2,
+ ),
+ torch.nn.Tanh(),
+ )
+ if global_channels > 0:
+ self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ # reset parameters
+ self.reset_parameters()
+
+ def forward(
+ self, c: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ c (Tensor): Input tensor (B, in_channels, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (B, out_channels, T * prod(upsample_scales)).
+
+ """
+ c = self.input_conv(c)
+ if g is not None:
+ c = c + self.global_conv(g)
+ for i in range(self.num_upsamples):
+ c = self.upsamples[i](c)
+ cs = 0.0 # running sum of the residual block outputs of this stage
+ for j in range(self.num_blocks):
+ cs += self.blocks[i * self.num_blocks + j](c)
+ c = cs / self.num_blocks
+ c = self.output_conv(c)
+
+ return c
+
+ def reset_parameters(self):
+ """Reset parameters.
+
+ This initialization follows the official implementation manner.
+ https://github.com/jik876/hifi-gan/blob/master/models.py
+
+ """
+
+ def _reset_parameters(m: torch.nn.Module):
+ if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
+ m.weight.data.normal_(0.0, 0.01)
+ logging.debug(f"Reset parameters in {m}.")
+
+ self.apply(_reset_parameters)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m: torch.nn.Module):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ logging.debug(f"Weight norm is removed from {m}.") # only after success
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization to all of the conv / transposed conv layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d) or isinstance(
+ m, torch.nn.ConvTranspose1d
+ ):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def inference(
+ self, c: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Perform inference on a single, unbatched input.
+
+ Args:
+ c (torch.Tensor): Input tensor (T, in_channels).
+ g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (T * prod(upsample_scales), out_channels).
+
+ """
+ if g is not None:
+ g = g.unsqueeze(0)
+ c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g)
+ return c.squeeze(0).transpose(1, 0)
+
+
+class ResidualBlock(torch.nn.Module):
+ """Residual block module in HiFiGAN (one residual step per dilation)."""
+
+ def __init__(
+ self,
+ kernel_size: int = 3,
+ channels: int = 512,
+ dilations: List[int] = [1, 3, 5],
+ bias: bool = True,
+ use_additional_convs: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ ):
+ """Initialize ResidualBlock module.
+
+ Args:
+ kernel_size (int): Kernel size of dilation convolution layer (must be odd).
+ channels (int): Number of channels for convolution layer.
+ dilations (List[int]): List of dilation factors.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ use_additional_convs (bool): Whether to use additional convolution layers.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+
+ """
+ super().__init__()
+ self.use_additional_convs = use_additional_convs
+ self.convs1 = torch.nn.ModuleList()
+ if use_additional_convs:
+ self.convs2 = torch.nn.ModuleList()
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
+ for dilation in dilations:
+ self.convs1 += [
+ torch.nn.Sequential(
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation,
+ bias=bias,
+ padding=(kernel_size - 1) // 2 * dilation,
+ ),
+ )
+ ]
+ if use_additional_convs:
+ self.convs2 += [
+ torch.nn.Sequential(
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ bias=bias,
+ padding=(kernel_size - 1) // 2,
+ ),
+ )
+ ]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply each conv stage in turn, adding its output back onto the input.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+
+ """
+ for idx in range(len(self.convs1)):
+ xt = self.convs1[idx](x)
+ if self.use_additional_convs:
+ xt = self.convs2[idx](xt)
+ x = xt + x
+ return x
+
+
+class HiFiGANPeriodDiscriminator(torch.nn.Module):
+ """HiFiGAN period discriminator module."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ period: int = 3,
+ kernel_sizes: List[int] = [5, 3],
+ channels: int = 32,
+ downsample_scales: List[int] = [3, 3, 3, 3, 1],
+ max_downsample_channels: int = 1024,
+ bias: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ use_weight_norm: bool = True,
+ use_spectral_norm: bool = False,
+ ):
+ """Initialize HiFiGANPeriodDiscriminator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ period (int): Period.
+ kernel_sizes (list): Kernel sizes of initial conv layers and the final conv
+ layer (the final conv is built with height kernel_sizes[1] - 1).
+ channels (int): Number of initial channels.
+ downsample_scales (List[int]): List of downsampling scales.
+ max_downsample_channels (int): Number of maximum downsampling channels.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ NOTE(review): accepted but not forwarded to the Conv2d layers below,
+ so they are built with torch's default bias setting - confirm intent.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+ use_weight_norm (bool): Whether to use weight norm.
+ If set to true, it will be applied to all of the conv layers.
+ use_spectral_norm (bool): Whether to use spectral norm.
+ If set to true, it will be applied to all of the conv layers.
+
+ """
+ super().__init__()
+ assert len(kernel_sizes) == 2
+ assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
+ assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."
+
+ self.period = period
+ self.convs = torch.nn.ModuleList()
+ in_chs = in_channels
+ out_chs = channels
+ for downsample_scale in downsample_scales:
+ self.convs += [
+ torch.nn.Sequential(
+ torch.nn.Conv2d(
+ in_chs,
+ out_chs,
+ (kernel_sizes[0], 1),
+ (downsample_scale, 1),
+ padding=((kernel_sizes[0] - 1) // 2, 0),
+ ),
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ )
+ ]
+ in_chs = out_chs
+ # NOTE(kan-bayashi): Use downsample_scale + 1?
+ out_chs = min(out_chs * 4, max_downsample_channels)
+ self.output_conv = torch.nn.Conv2d(
+ out_chs,
+ out_channels,
+ (kernel_sizes[1] - 1, 1),
+ 1,
+ padding=((kernel_sizes[1] - 1) // 2, 0),
+ )
+
+ if use_weight_norm and use_spectral_norm:
+ raise ValueError("Either use use_weight_norm or use_spectral_norm.")
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ # apply spectral norm
+ if use_spectral_norm:
+ self.apply_spectral_norm()
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+
+ Returns:
+ List[Tensor]: Output of every conv layer, then the flattened final output.
+
+ """
+ # transform 1d to 2d -> (B, C, T/P, P); reflect-pad T up to a multiple of P first
+ b, c, t = x.shape
+ if t % self.period != 0:
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t += n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ # forward conv
+ outs = []
+ for layer in self.convs:
+ x = layer(x)
+ outs += [x]
+ x = self.output_conv(x)
+ x = torch.flatten(x, 1, -1)
+ outs += [x]
+
+ return outs
+
+ def apply_weight_norm(self):
+ """Apply weight normalization to all of the Conv2d layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def apply_spectral_norm(self):
+ """Apply spectral normalization to all of the Conv2d layers."""
+
+ def _apply_spectral_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.spectral_norm(m)
+ logging.debug(f"Spectral norm is applied to {m}.")
+
+ self.apply(_apply_spectral_norm)
+
+
+class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
+ """HiFiGAN multi-period discriminator module (one period discriminator per period)."""
+
+ def __init__(
+ self,
+ periods: List[int] = [2, 3, 5, 7, 11],
+ discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [5, 3],
+ "channels": 32,
+ "downsample_scales": [3, 3, 3, 3, 1],
+ "max_downsample_channels": 1024,
+ "bias": True,
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ ):
+ """Initialize HiFiGANMultiPeriodDiscriminator module.
+
+ Args:
+ periods (List[int]): List of periods.
+ discriminator_params (Dict[str, Any]): Parameters for hifi-gan period
+ discriminator module. The period parameter will be overwritten.
+
+ """
+ super().__init__()
+ self.discriminators = torch.nn.ModuleList()
+ for period in periods:
+ params = copy.deepcopy(discriminator_params)
+ params["period"] = period
+ self.discriminators += [HiFiGANPeriodDiscriminator(**params)]
+
+ def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List: List of list of each discriminator outputs, which consists of each
+ layer output tensors.
+
+ """
+ outs = []
+ for f in self.discriminators:
+ outs += [f(x)]
+
+ return outs
+
+
+class HiFiGANScaleDiscriminator(torch.nn.Module):
+ """HiFi-GAN scale discriminator module."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ kernel_sizes: List[int] = [15, 41, 5, 3],
+ channels: int = 128,
+ max_downsample_channels: int = 1024,
+ max_groups: int = 16,
+ bias: bool = True,
+ downsample_scales: List[int] = [2, 2, 4, 4, 1],
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ use_weight_norm: bool = True,
+ use_spectral_norm: bool = False,
+ ):
+ """Initialize HiFiGAN scale discriminator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ kernel_sizes (List[int]): List of four kernel sizes. The first will be used
+ for the first conv layer, and the second is for downsampling part, and
+ the remaining two are for the last two output layers.
+ channels (int): Initial number of channels for conv layer.
+ max_downsample_channels (int): Maximum number of channels for downsampling layers.
+ max_groups (int): Maximum number of groups used by the downsampling conv layers.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ downsample_scales (List[int]): List of downsampling scales.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+ use_weight_norm (bool): Whether to use weight norm. If set to true, it will
+ be applied to all of the conv layers.
+ use_spectral_norm (bool): Whether to use spectral norm. If set to true, it
+ will be applied to all of the conv layers.
+
+ """
+ super().__init__()
+ self.layers = torch.nn.ModuleList()
+
+ # check kernel size is valid
+ assert len(kernel_sizes) == 4
+ for ks in kernel_sizes:
+ assert ks % 2 == 1
+
+ # add first layer
+ self.layers += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_channels,
+ channels,
+ # NOTE(kan-bayashi): Use always the same kernel size
+ kernel_sizes[0],
+ bias=bias,
+ padding=(kernel_sizes[0] - 1) // 2,
+ ),
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+ )
+ ]
+
+ # add downsample layers
+ in_chs = channels
+ out_chs = channels
+ # NOTE(kan-bayashi): Remove hard coding?
+ groups = 4
+ for downsample_scale in downsample_scales:
+ self.layers += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_sizes[1],
+ stride=downsample_scale,
+ padding=(kernel_sizes[1] - 1) // 2,
+ groups=groups,
+ bias=bias,
+ ),
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ )
+ ]
+ in_chs = out_chs
+ # NOTE(kan-bayashi): Remove hard coding?
+ out_chs = min(in_chs * 2, max_downsample_channels)
+ # NOTE(kan-bayashi): Remove hard coding?
+ groups = min(groups * 4, max_groups)
+
+ # add final layers
+ out_chs = min(in_chs * 2, max_downsample_channels)
+ self.layers += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_sizes[2],
+ stride=1,
+ padding=(kernel_sizes[2] - 1) // 2,
+ bias=bias,
+ ),
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+ )
+ ]
+ self.layers += [
+ torch.nn.Conv1d(
+ out_chs,
+ out_channels,
+ kernel_size=kernel_sizes[3],
+ stride=1,
+ padding=(kernel_sizes[3] - 1) // 2,
+ bias=bias,
+ ),
+ ]
+
+ if use_weight_norm and use_spectral_norm:
+ raise ValueError("Either use use_weight_norm or use_spectral_norm.")
+
+ # apply weight norm
+ self.use_weight_norm = use_weight_norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ # apply spectral norm
+ self.use_spectral_norm = use_spectral_norm
+ if use_spectral_norm:
+ self.apply_spectral_norm()
+
+ # backward compatibility
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List[Tensor]: List of output tensors of each layer.
+
+ """
+ outs = []
+ for f in self.layers:
+ x = f(x)
+ outs += [x]
+
+ return outs
+
+ def apply_weight_norm(self):
+ """Apply weight normalization to all of the Conv1d layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def apply_spectral_norm(self):
+ """Apply spectral normalization to all of the Conv1d layers."""
+
+ def _apply_spectral_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d):
+ torch.nn.utils.spectral_norm(m)
+ logging.debug(f"Spectral norm is applied to {m}.")
+
+ self.apply(_apply_spectral_norm)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ torch.nn.utils.remove_weight_norm(m)
+ logging.debug(f"Weight norm is removed from {m}.") # only after success
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def remove_spectral_norm(self):
+ """Remove spectral normalization module from all of the layers."""
+
+ def _remove_spectral_norm(m):
+ try:
+ torch.nn.utils.remove_spectral_norm(m)
+ logging.debug(f"Spectral norm is removed from {m}.") # only after success
+ except ValueError: # this module didn't have spectral norm
+ return
+
+ self.apply(_remove_spectral_norm)
+
+ def _load_state_dict_pre_hook(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ """Fix the compatibility of weight / spectral normalization issue.
+
+ Some pretrained models are trained with configs that use weight / spectral
+ normalization, but actually, the norm is not applied. This causes the mismatch
+ of the parameters with configs. To solve this issue, when parameter mismatch
+ happens in loading pretrained model, we remove the norm from the current model.
+
+ See also:
+ - https://github.com/espnet/espnet/pull/5240
+ - https://github.com/espnet/espnet/pull/5249
+ - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
+
+ """
+ current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
+ if self.use_weight_norm and any(
+ [k.endswith("weight") for k in current_module_keys]
+ ):
+ logging.warning(
+ "It seems weight norm is not applied in the pretrained model but the"
+ " current model uses it. To keep the compatibility, we remove the norm"
+ " from the current model. This may cause unexpected behavior due to the"
+ " parameter mismatch in finetuning. To avoid this issue, please change"
+ " the following parameters in config to false:\n"
+ " - discriminator_params.follow_official_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
+ "\n"
+ "See also:\n"
+ " - https://github.com/espnet/espnet/pull/5240\n"
+ " - https://github.com/espnet/espnet/pull/5249"
+ )
+ self.remove_weight_norm()
+ self.use_weight_norm = False
+ for k in current_module_keys:
+ if k.endswith("weight_g") or k.endswith("weight_v"):
+ del state_dict[k]
+
+ if self.use_spectral_norm and any(
+ [k.endswith("weight") for k in current_module_keys]
+ ):
+ logging.warning(
+ "It seems spectral norm is not applied in the pretrained model but the"
+ " current model uses it. To keep the compatibility, we remove the norm"
+ " from the current model. This may cause unexpected behavior due to the"
+ " parameter mismatch in finetuning. To avoid this issue, please change"
+ " the following parameters in config to false:\n"
+ " - discriminator_params.follow_official_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
+ "\n"
+ "See also:\n"
+ " - https://github.com/espnet/espnet/pull/5240\n"
+ " - https://github.com/espnet/espnet/pull/5249"
+ )
+ self.remove_spectral_norm()
+ self.use_spectral_norm = False
+ for k in current_module_keys:
+ if (
+ k.endswith("weight_u")
+ or k.endswith("weight_v")
+ or k.endswith("weight_orig")
+ ):
+ del state_dict[k]
+
+
+class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
+ """HiFi-GAN multi-scale discriminator module."""
+
+ def __init__(
+ self,
+ scales: int = 3,
+ downsample_pooling: str = "AvgPool1d",
+ # follow the official implementation setting
+ downsample_pooling_params: Dict[str, Any] = {
+ "kernel_size": 4,
+ "stride": 2,
+ "padding": 2,
+ },
+ discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [15, 41, 5, 3],
+ "channels": 128,
+ "max_downsample_channels": 1024,
+ "max_groups": 16,
+ "bias": True,
+ "downsample_scales": [2, 2, 4, 4, 1],
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ },
+ follow_official_norm: bool = False,
+ ):
+ """Initialize HiFiGAN multi-scale discriminator module.
+
+ Args:
+ scales (int): Number of multi-scales.
+ downsample_pooling (str): Pooling module name for downsampling of the
+ inputs.
+ downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling
+ module.
+ discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale
+ discriminator module.
+ follow_official_norm (bool): Whether to follow the norm setting of the
+ official implementation. The first discriminator uses spectral norm
+ and the other discriminators use weight norm.
+
+ """
+ super().__init__()
+ self.discriminators = torch.nn.ModuleList()
+
+ # add discriminators
+ for i in range(scales):
+ params = copy.deepcopy(discriminator_params)
+ if follow_official_norm:
+ if i == 0:
+ params["use_weight_norm"] = False
+ params["use_spectral_norm"] = True
+ else:
+ params["use_weight_norm"] = True
+ params["use_spectral_norm"] = False
+ self.discriminators += [HiFiGANScaleDiscriminator(**params)]
+ self.pooling = None
+ if scales > 1:
+ self.pooling = getattr(torch.nn, downsample_pooling)(
+ **downsample_pooling_params
+ )
+
+ def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List[List[torch.Tensor]]: List of list of each discriminator outputs,
+ which consists of each layer output tensors.
+
+ """
+ outs = []
+ for f in self.discriminators:
+ outs += [f(x)]
+ if self.pooling is not None:
+ x = self.pooling(x)
+
+ return outs
+
+
+class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
+ """HiFi-GAN multi-scale + multi-period discriminator module."""
+
+ def __init__(
+ self,
+ # Multi-scale discriminator related
+ scales: int = 3,
+ scale_downsample_pooling: str = "AvgPool1d",
+ scale_downsample_pooling_params: Dict[str, Any] = {
+ "kernel_size": 4,
+ "stride": 2,
+ "padding": 2,
+ },
+ scale_discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [15, 41, 5, 3],
+ "channels": 128,
+ "max_downsample_channels": 1024,
+ "max_groups": 16,
+ "bias": True,
+ "downsample_scales": [2, 2, 4, 4, 1],
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ },
+ follow_official_norm: bool = True,
+ # Multi-period discriminator related
+ periods: List[int] = [2, 3, 5, 7, 11],
+ period_discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [5, 3],
+ "channels": 32,
+ "downsample_scales": [3, 3, 3, 3, 1],
+ "max_downsample_channels": 1024,
+ "bias": True,
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ ):
+ """Initialize HiFiGAN multi-scale + multi-period discriminator module.
+
+ Args:
+ scales (int): Number of multi-scales.
+ scale_downsample_pooling (str): Pooling module name for downsampling of the
+ inputs.
+ scale_downsample_pooling_params (dict): Parameters for the above pooling
+ module.
+ scale_discriminator_params (dict): Parameters for hifi-gan scale
+ discriminator module.
+ follow_official_norm (bool): Whether to follow the norm setting of the
+ official implementation. The first discriminator uses spectral norm and
+ the other discriminators use weight norm.
+ periods (list): List of periods.
+ period_discriminator_params (dict): Parameters for hifi-gan period
+ discriminator module. The period parameter will be overwritten.
+
+ """
+ super().__init__()
+ self.msd = HiFiGANMultiScaleDiscriminator(
+ scales=scales,
+ downsample_pooling=scale_downsample_pooling,
+ downsample_pooling_params=scale_downsample_pooling_params,
+ discriminator_params=scale_discriminator_params,
+ follow_official_norm=follow_official_norm,
+ )
+ self.mpd = HiFiGANMultiPeriodDiscriminator(
+ periods=periods,
+ discriminator_params=period_discriminator_params,
+ )
+
+ def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List[List[Tensor]]: List of list of each discriminator outputs,
+ which consists of each layer output tensors. The multi-scale outputs
+ come first, followed by the multi-period outputs.
+
+ """
+ msd_outs = self.msd(x)
+ mpd_outs = self.mpd(x)
+ return msd_outs + mpd_outs
diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py
new file mode 100755
index 000000000..91a35e360
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/infer.py
@@ -0,0 +1,233 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script runs model inference on the test set and saves ground-truth and generated wavs.
+
+Usage:
+./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir ./vits/exp \
+ --max-duration 500
+"""
+
+
+import argparse
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import List
+
+import k2
+import torch
+import torch.nn as nn
+import torchaudio
+
+from train import get_model, get_params
+from tokenizer import Tokenizer
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+from tts_datamodule import LJSpeechTtsDataModule
+
+
+def get_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ # Script-specific options: --epoch, --exp-dir and --tokens.
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def infer_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ tokenizer: Tokenizer,
+) -> None:
+ """Synthesize audio for every batch of the dataset.
+ The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ tokenizer:
+ Used to convert the batch's tokens to token ids; its blank_id is the padding value.
+ """
+ # Background worker that saves the audio to disk.
+ def _save_worker(
+ batch_size: int,
+ cut_ids: List[str],
+ audio: torch.Tensor,
+ audio_pred: torch.Tensor,
+ audio_lens: List[int],
+ audio_lens_pred: List[int],
+ ):
+ for i in range(batch_size):
+ torchaudio.save(
+ str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
+ audio[i:i + 1, :audio_lens[i]],
+ sample_rate=params.sampling_rate,
+ )
+ torchaudio.save(
+ str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
+ audio_pred[i:i + 1, :audio_lens_pred[i]],
+ sample_rate=params.sampling_rate,
+ )
+
+ device = next(model.parameters()).device
+ num_cuts = 0
+ log_interval = 5
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ futures = []
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ for batch_idx, batch in enumerate(dl):
+ batch_size = len(batch["tokens"])
+
+ tokens = batch["tokens"]
+ tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = k2.RaggedTensor(tokens)
+ row_splits = tokens.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1]
+ tokens = tokens.to(device)
+ tokens_lens = tokens_lens.to(device)
+ # tensor of shape (B, T)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+
+ audio = batch["audio"]
+ audio_lens = batch["audio_lens"].tolist()
+ cut_ids = [cut.id for cut in batch["cut"]]
+
+ audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
+ audio_pred = audio_pred.detach().cpu()
+ # convert to samples (summed durations scaled by params.frame_shift)
+ audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+
+ futures.append(
+ executor.submit(
+ _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
+ )
+ )
+
+ num_cuts += batch_size
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ # wait for every save job; result() re-raises any exception raised in the worker
+ for f in futures:
+ f.result()
+
+
+@torch.no_grad()
+def main() -> None:
+ parser = get_parser()
+ LJSpeechTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.suffix = f"epoch-{params.epoch}"
+
+ params.res_dir = params.exp_dir / "infer" / params.suffix
+ params.save_wav_dir = params.res_dir / "wav"
+ params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+ setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+ logging.info("Infer started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ logging.info(f"Device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model.to(device)
+ model.eval()
+
+ num_param_g = sum([p.numel() for p in model.generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+ # we need the cuts in each batch: their ids are used to name the saved wav files.
+ args.return_cuts = True
+ ljspeech = LJSpeechTtsDataModule(args)
+
+ test_cuts = ljspeech.test_cuts()
+ test_dl = ljspeech.test_dataloaders(test_cuts)
+
+ infer_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ )
+
+ logging.info(f"Wav files are saved to {params.save_wav_dir}")
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py
new file mode 100644
index 000000000..21aaad6e7
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/loss.py
@@ -0,0 +1,336 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""HiFiGAN-related loss modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+from typing import List, Tuple, Union
+
+import torch
+import torch.distributions as D
+import torch.nn.functional as F
+
+from lhotse.features.kaldi import Wav2LogFilterBank
+
+
+class GeneratorAdversarialLoss(torch.nn.Module):
+ """Generator adversarial loss module.
+
+ The criterion is chosen once at construction time: "mse" compares the
+ discriminator outputs against an all-ones target, "hinge" uses the negated
+ mean of the outputs.
+ """
+
+ def __init__(
+ self,
+ average_by_discriminators: bool = True,
+ loss_type: str = "mse",
+ ):
+ """Initialize GeneratorAdversarialLoss module.
+
+ Args:
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ loss_type (str): Loss type, "mse" or "hinge".
+
+ """
+ super().__init__()
+ self.average_by_discriminators = average_by_discriminators
+ assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+ if loss_type == "mse":
+ self.criterion = self._mse_loss
+ else:
+ self.criterion = self._hinge_loss
+
+ def forward(
+ self,
+ outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ ) -> torch.Tensor:
+ """Calculate generator adversarial loss.
+
+ Args:
+ outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+ outputs.
+
+ Returns:
+ Tensor: Generator adversarial loss value.
+
+ """
+ if isinstance(outputs, (tuple, list)):
+ adv_loss = 0.0
+ for i, outputs_ in enumerate(outputs):
+ if isinstance(outputs_, (tuple, list)):
+ # NOTE(kan-bayashi): case including feature maps
+ # Only the last element of the inner list is scored.
+ outputs_ = outputs_[-1]
+ adv_loss += self.criterion(outputs_)
+ if self.average_by_discriminators:
+ # i + 1 is the number of entries in outputs.
+ # NOTE(review): assumes outputs is non-empty -- confirm.
+ adv_loss /= i + 1
+ else:
+ adv_loss = self.criterion(outputs)
+
+ return adv_loss
+
+ def _mse_loss(self, x):
+ """Return the MSE between x and an all-ones tensor of the same size."""
+ return F.mse_loss(x, x.new_ones(x.size()))
+
+ def _hinge_loss(self, x):
+ """Return the negated mean of x."""
+ return -x.mean()
+
+
+class DiscriminatorAdversarialLoss(torch.nn.Module):
+ """Discriminator adversarial loss module.
+
+ Holds one criterion for outputs computed from groundtruth ("real") and one
+ for outputs computed from generated samples ("fake"); both are selected by
+ ``loss_type`` at construction time.
+ """
+
+ def __init__(
+ self,
+ average_by_discriminators: bool = True,
+ loss_type: str = "mse",
+ ):
+ """Initialize DiscriminatorAdversarialLoss module.
+
+ Args:
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ loss_type (str): Loss type, "mse" or "hinge".
+
+ """
+ super().__init__()
+ self.average_by_discriminators = average_by_discriminators
+ assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+ if loss_type == "mse":
+ self.fake_criterion = self._mse_fake_loss
+ self.real_criterion = self._mse_real_loss
+ else:
+ self.fake_criterion = self._hinge_fake_loss
+ self.real_criterion = self._hinge_real_loss
+
+ def forward(
+ self,
+ outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate discriminator adversarial loss.
+
+ Args:
+ outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+ outputs calculated from generator.
+ outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+ outputs calculated from groundtruth.
+
+ Returns:
+ Tensor: Discriminator real loss value.
+ Tensor: Discriminator fake loss value.
+
+ """
+ if isinstance(outputs, (tuple, list)):
+ real_loss = 0.0
+ fake_loss = 0.0
+ # outputs_hat and outputs are walked in lockstep, one pair per entry.
+ for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
+ if isinstance(outputs_hat_, (tuple, list)):
+ # NOTE(kan-bayashi): case including feature maps
+ # Only the last element of each inner list is scored.
+ outputs_hat_ = outputs_hat_[-1]
+ outputs_ = outputs_[-1]
+ real_loss += self.real_criterion(outputs_)
+ fake_loss += self.fake_criterion(outputs_hat_)
+ if self.average_by_discriminators:
+ # i + 1 is the number of pairs that were accumulated.
+ # NOTE(review): assumes the inputs are non-empty -- confirm.
+ fake_loss /= i + 1
+ real_loss /= i + 1
+ else:
+ real_loss = self.real_criterion(outputs)
+ fake_loss = self.fake_criterion(outputs_hat)
+
+ return real_loss, fake_loss
+
+ def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+ """Return the MSE between x and an all-ones tensor of the same size."""
+ return F.mse_loss(x, x.new_ones(x.size()))
+
+ def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+ """Return the MSE between x and an all-zeros tensor of the same size."""
+ return F.mse_loss(x, x.new_zeros(x.size()))
+
+ def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+ """Return -mean(min(x - 1, 0))."""
+ return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
+
+ def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+ """Return -mean(min(-x - 1, 0))."""
+ return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
+
+
+class FeatureMatchLoss(torch.nn.Module):
+ """Feature matching loss module.
+
+ Accumulates the L1 distance between corresponding discriminator feature
+ maps computed from generated and groundtruth inputs.
+ """
+
+ def __init__(
+ self,
+ average_by_layers: bool = True,
+ average_by_discriminators: bool = True,
+ include_final_outputs: bool = False,
+ ):
+ """Initialize FeatureMatchLoss module.
+
+ Args:
+ average_by_layers (bool): Whether to average the loss by the number
+ of layers.
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ include_final_outputs (bool): Whether to include the final output of
+ each discriminator for loss calculation.
+
+ """
+ super().__init__()
+ self.average_by_layers = average_by_layers
+ self.average_by_discriminators = average_by_discriminators
+ self.include_final_outputs = include_final_outputs
+
+ def forward(
+ self,
+ feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+ feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+ ) -> torch.Tensor:
+ """Calculate feature matching loss.
+
+ Args:
+ feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
+ discriminator outputs or list of discriminator outputs calculated
+ from generator's outputs.
+ feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
+ discriminator outputs or list of discriminator outputs calculated
+ from groundtruth.
+
+ Returns:
+ Tensor: Feature matching loss value.
+
+ """
+ feat_match_loss = 0.0
+ for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
+ feat_match_loss_ = 0.0
+ if not self.include_final_outputs:
+ # Drop the last entry (the final output) of each discriminator.
+ feats_hat_ = feats_hat_[:-1]
+ feats_ = feats_[:-1]
+ for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
+ # The groundtruth feature is detached, so it acts as a fixed target.
+ feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
+ if self.average_by_layers:
+ # j + 1 is the number of feature maps that were compared.
+ feat_match_loss_ /= j + 1
+ feat_match_loss += feat_match_loss_
+ if self.average_by_discriminators:
+ # i + 1 is the number of discriminators that were accumulated.
+ feat_match_loss /= i + 1
+
+ return feat_match_loss
+
+
+class MelSpectrogramLoss(torch.nn.Module):
+ """Mel-spectrogram loss.
+
+ Both waveforms are converted with the same ``Wav2LogFilterBank`` instance
+ and compared with an L1 loss.
+ """
+
+ def __init__(
+ self,
+ sampling_rate: int = 22050,
+ frame_length: int = 1024, # in samples
+ frame_shift: int = 256, # in samples
+ n_mels: int = 80,
+ use_fft_mag: bool = True,
+ ):
+ """Initialize MelSpectrogramLoss module.
+
+ Args:
+ sampling_rate (int): Sampling rate of the waveforms.
+ frame_length (int): Frame length in samples.
+ frame_shift (int): Frame shift in samples.
+ n_mels (int): Number of filters, passed on as ``num_filters``.
+ use_fft_mag (bool): Forwarded unchanged to ``Wav2LogFilterBank``.
+
+ """
+ super().__init__()
+ # Wav2LogFilterBank takes the frame sizes in seconds, hence the division
+ # of the sample counts by the sampling rate.
+ self.wav_to_mel = Wav2LogFilterBank(
+ sampling_rate=sampling_rate,
+ frame_length=frame_length / sampling_rate, # in second
+ frame_shift=frame_shift / sampling_rate, # in second
+ use_fft_mag=use_fft_mag,
+ num_filters=n_mels,
+ )
+
+ def forward(
+ self,
+ y_hat: torch.Tensor,
+ y: torch.Tensor,
+ return_mel: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
+ """Calculate Mel-spectrogram loss.
+
+ Args:
+ y_hat (Tensor): Generated waveform tensor (B, 1, T).
+ y (Tensor): Groundtruth waveform tensor (B, 1, T).
+ return_mel (bool): Whether to also return the features computed from
+ ``y_hat`` and ``y``.
+
+ Returns:
+ Tensor: Mel-spectrogram loss value.
+ Tuple[Tensor, Tensor]: ``(mel_hat, mel)``, returned together with the
+ loss only when ``return_mel`` is True.
+
+ """
+ # Remove the channel dimension before feature extraction.
+ mel_hat = self.wav_to_mel(y_hat.squeeze(1))
+ mel = self.wav_to_mel(y.squeeze(1))
+ mel_loss = F.l1_loss(mel_hat, mel)
+
+ if return_mel:
+ return mel_loss, (mel_hat, mel)
+
+ return mel_loss
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
+
+"""VITS-related loss modules.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+
+class KLDivergenceLoss(torch.nn.Module):
+ """KL divergence loss."""
+
+ def forward(
+ self,
+ z_p: torch.Tensor,
+ logs_q: torch.Tensor,
+ m_p: torch.Tensor,
+ logs_p: torch.Tensor,
+ z_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calculate KL divergence loss.
+
+ Args:
+ z_p (Tensor): Flow hidden representation (B, H, T_feats).
+ logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
+ m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
+ logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
+ z_mask (Tensor): Mask tensor (B, 1, T_feats).
+
+ Returns:
+ Tensor: KL divergence loss.
+
+ """
+ # Cast every input to float32 before computing the loss.
+ z_p = z_p.float()
+ logs_q = logs_q.float()
+ m_p = m_p.float()
+ logs_p = logs_p.float()
+ z_mask = z_mask.float()
+ # Element-wise term:
+ # logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p)^2 * exp(-2 * logs_p)
+ kl = logs_p - logs_q - 0.5
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+ # Masked sum, normalized by the sum of the mask values.
+ kl = torch.sum(kl * z_mask)
+ loss = kl / torch.sum(z_mask)
+
+ return loss
+
+
+class KLDivergenceLossWithoutFlow(torch.nn.Module):
+ """KL divergence loss without flow."""
+
+ def forward(
+ self,
+ m_q: torch.Tensor,
+ logs_q: torch.Tensor,
+ m_p: torch.Tensor,
+ logs_p: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calculate KL divergence loss without flow.
+
+ Args:
+ m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
+ logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
+ m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
+ logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
+
+ Returns:
+ Tensor: Mean of the element-wise KL divergence between the posterior
+ and the prior normal distributions. No mask is applied here.
+ """
+ # The projected scales are logarithms, so exponentiate them to get the
+ # standard deviations expected by D.Normal.
+ posterior_norm = D.Normal(m_q, torch.exp(logs_q))
+ prior_norm = D.Normal(m_p, torch.exp(logs_p))
+ loss = D.kl_divergence(posterior_norm, prior_norm).mean()
+ return loss
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py
new file mode 100644
index 000000000..2b35654f5
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py
@@ -0,0 +1,81 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py
+
+"""Maximum path calculation module.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import warnings
+
+import numpy as np
+import torch
+from numba import njit, prange
+
+# Prefer the compiled Cython kernel from core.pyx. When the extension has not
+# been built, the import fails and maximum_path() uses the numba kernels
+# defined later in this module instead.
+try:
+    from .core import maximum_path_c
+
+    is_cython_avalable = True
+except ImportError:
+    is_cython_avalable = False
+    # The build instruction points at this recipe's directory (the path copied
+    # from ESPnet does not exist in this repository).
+    warnings.warn(
+        "Cython version is not available. Fallback to 'EXPERIMENTAL' numba version. "
+        "If you want to use the cython version, please build it as follows: "
+        "`cd egs/ljspeech/TTS/vits/monotonic_align; python setup.py build_ext --inplace`"
+    )
+
+
+def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
+ """Calculate maximum path.
+
+ The search runs on the CPU with numpy arrays, using the Cython kernel when
+ it could be imported and the numba kernel otherwise. The result is moved
+ back to the device and dtype of ``neg_x_ent``.
+
+ Args:
+ neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
+ attn_mask (Tensor): Attention mask (B, T_feats, T_text).
+
+ Returns:
+ Tensor: Maximum path tensor (B, T_feats, T_text).
+
+ """
+ device, dtype = neg_x_ent.device, neg_x_ent.dtype
+ # astype() returns a new array by default, so the in-place updates done by
+ # the kernels do not touch the caller's tensor.
+ neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
+ path = np.zeros(neg_x_ent.shape, dtype=np.int32)
+ # Per-item lengths read from the mask: the sum over dim 1 (T_feats) taken at
+ # text index 0, and the sum over dim 2 (T_text) taken at feats index 0.
+ t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
+ t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
+ # Both kernels fill path in place.
+ if is_cython_avalable:
+ maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
+ else:
+ maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
+
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
+
+
+@njit
+def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
+ """Calculate a single maximum path with numba.
+
+ ``value`` is overwritten in place with cumulative scores by the first loop;
+ the second loop walks back from row ``t_y - 1`` and column ``t_x - 1`` and
+ writes 1 into ``path`` for every visited cell.
+ """
+ index = t_x - 1
+ # Forward pass: add to each cell the better of its two predecessors in the
+ # previous row, (y - 1, x) and (y - 1, x - 1).
+ for y in range(t_y):
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+ if x == y:
+ # Staying in the same column is not possible on the diagonal.
+ v_cur = max_neg_val
+ else:
+ v_cur = value[y - 1, x]
+ if x == 0:
+ if y == 0:
+ # Start cell: nothing precedes it.
+ v_prev = 0.0
+ else:
+ v_prev = max_neg_val
+ else:
+ v_prev = value[y - 1, x - 1]
+ value[y, x] += max(v_prev, v_cur)
+
+ # Backward pass: mark one column per row, moving one column to the left
+ # when forced (index == y) or when the left predecessor scores higher.
+ # NOTE(review): assumes t_x <= t_y so that index is already 0 at y == 0;
+ # otherwise value[y - 1, ...] would be read with y - 1 == -1 -- confirm.
+ for y in range(t_y - 1, -1, -1):
+ path[y, index] = 1
+ if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
+ index = index - 1
+
+
+@njit(parallel=True)
+def maximum_path_numba(paths, values, t_ys, t_xs):
+ """Calculate batch maximum path with numba.
+
+ Applies ``maximum_path_each_numba`` to every batch item; the items are
+ iterated with ``prange`` and ``paths`` is filled in place.
+ """
+ for i in prange(paths.shape[0]):
+ maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/core.pyx b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx
new file mode 100644
index 000000000..c02c2d02e
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx
@@ -0,0 +1,51 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx
+
+"""Maximum path calculation module with cython optimization.
+
+This code is copied from https://github.com/jaywalnut310/vits, with the code format modified.
+
+"""
+
+cimport cython
+
+from cython.parallel import prange
+
+
+# Single-sequence kernel. The first loop overwrites value in place with
+# cumulative scores; the second loop walks back from row t_y - 1 and column
+# t_x - 1 and writes 1 into path for every visited cell.
+# Bounds checking and negative-index wraparound are both disabled here.
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
+ cdef int x
+ cdef int y
+ cdef float v_prev
+ cdef float v_cur
+ # NOTE: tmp is declared but never used in this function.
+ cdef float tmp
+ cdef int index = t_x - 1
+
+ # Forward pass: add to each cell the better of its two predecessors in the
+ # previous row, (y - 1, x) and (y - 1, x - 1).
+ for y in range(t_y):
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+ if x == y:
+ # Staying in the same column is not possible on the diagonal.
+ v_cur = max_neg_val
+ else:
+ v_cur = value[y - 1, x]
+ if x == 0:
+ if y == 0:
+ # Start cell: nothing precedes it.
+ v_prev = 0.0
+ else:
+ v_prev = max_neg_val
+ else:
+ v_prev = value[y - 1, x - 1]
+ value[y, x] += max(v_prev, v_cur)
+
+ # Backward pass: mark one column per row, moving one column to the left
+ # when forced (index == y) or when the left predecessor scores higher.
+ # NOTE(review): assumes t_x <= t_y so that index is already 0 at y == 0;
+ # otherwise value[y - 1, ...] would be read at row -1 with wraparound
+ # disabled -- confirm.
+ for y in range(t_y - 1, -1, -1):
+ path[y, index] = 1
+ if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
+ index = index - 1
+
+
+# Batch entry point imported by the package __init__: runs maximum_path_each
+# on every item along the first axis, iterating with prange and without the
+# GIL. paths is filled in place and values is modified in place.
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
+ cdef int b = paths.shape[0]
+ cdef int i
+ for i in prange(b, nogil=True):
+ maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py
new file mode 100644
index 000000000..33d75e176
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/monotonic_align/setup.py
@@ -0,0 +1,31 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
+"""Setup cython code."""
+
+from Cython.Build import cythonize
+from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext as _build_ext
+
+
+class build_ext(_build_ext):
+ """Overwrite build_ext.
+
+ Adds the numpy header directory to the include path after the regular
+ option handling has run.
+ """
+
+ def finalize_options(self):
+ """Prevent numpy from thinking it is still in its setup process."""
+ _build_ext.finalize_options(self)
+ # The flag is cleared before numpy is imported below.
+ __builtins__.__NUMPY_SETUP__ = False
+ # Imported here rather than at module level, i.e. only once the build
+ # command is being finalized.
+ import numpy
+
+ self.include_dirs.append(numpy.get_include())
+
+
+# A single extension module named "core", built from core.pyx.
+exts = [
+ Extension(
+ name="core",
+ sources=["core.pyx"],
+ )
+]
+# The sources are cythonized with language level 3 and built with the
+# customized build_ext command defined above.
+setup(
+ name="monotonic_align",
+ ext_modules=cythonize(exts, language_level=3),
+ cmdclass={"build_ext": build_ext},
+)
diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py
new file mode 100644
index 000000000..6b8a5be52
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/posterior_encoder.py
@@ -0,0 +1,117 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Posterior encoder module in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+from typing import Optional, Tuple
+
+import torch
+
+from icefall.utils import make_pad_mask
+from wavenet import WaveNet, Conv1d
+
+
+class PosteriorEncoder(torch.nn.Module):
+ """Posterior encoder module in VITS.
+
+ This is a module of posterior encoder described in `Conditional Variational
+ Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
+
+ .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`: https://arxiv.org/abs/2106.06103
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 513,
+ out_channels: int = 192,
+ hidden_channels: int = 192,
+ kernel_size: int = 5,
+ layers: int = 16,
+ stacks: int = 1,
+ base_dilation: int = 1,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ bias: bool = True,
+ use_weight_norm: bool = True,
+ ):
+ """Initialize PosteriorEncoder module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ hidden_channels (int): Number of hidden channels.
+ kernel_size (int): Kernel size in WaveNet.
+ layers (int): Number of layers of WaveNet.
+ stacks (int): Number of repeat stacking of WaveNet.
+ base_dilation (int): Base dilation factor.
+ global_channels (int): Number of global conditioning channels.
+ dropout_rate (float): Dropout rate.
+ bias (bool): Whether to use bias parameters in conv.
+ use_weight_norm (bool): Whether to apply weight norm.
+
+ """
+ super().__init__()
+
+ # define modules
+ self.input_conv = Conv1d(in_channels, hidden_channels, 1)
+ # The WaveNet is built without its own first and last convolutions;
+ # input_conv and proj take care of the channel changes.
+ self.encoder = WaveNet(
+ in_channels=-1,
+ out_channels=-1,
+ kernel_size=kernel_size,
+ layers=layers,
+ stacks=stacks,
+ base_dilation=base_dilation,
+ residual_channels=hidden_channels,
+ aux_channels=-1,
+ gate_channels=hidden_channels * 2,
+ skip_channels=hidden_channels,
+ global_channels=global_channels,
+ dropout_rate=dropout_rate,
+ bias=bias,
+ use_weight_norm=use_weight_norm,
+ use_first_conv=False,
+ use_last_conv=False,
+ scale_residual=False,
+ scale_skip_connect=True,
+ )
+ # Twice the output channels: one half for the mean, one for the scale.
+ self.proj = Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(
+ self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T_feats).
+ x_lengths (Tensor): Length tensor (B,).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
+ Tensor: Projected mean tensor (B, out_channels, T_feats).
+ Tensor: Projected scale tensor (B, out_channels, T_feats).
+ Tensor: Mask tensor for input tensor (B, 1, T_feats).
+
+ """
+ # Inverted padding mask with an added channel axis, cast to the dtype
+ # and device of x so it can be multiplied into the activations.
+ x_mask = (
+ (~make_pad_mask(x_lengths))
+ .unsqueeze(1)
+ .to(
+ dtype=x.dtype,
+ device=x.device,
+ )
+ )
+ x = self.input_conv(x) * x_mask
+ x = self.encoder(x, x_mask, g=g)
+ stats = self.proj(x) * x_mask
+ # First half of the channels is the mean, second half the log-scale.
+ m, logs = stats.split(stats.size(1) // 2, dim=1)
+ # Sample z = m + eps * exp(logs) with eps drawn by randn_like, then mask.
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+
+ return z, m, logs, x_mask
diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py
new file mode 100644
index 000000000..2d6807cb7
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/residual_coupling.py
@@ -0,0 +1,229 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Residual affine coupling modules in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from flow import FlipFlow
+from wavenet import WaveNet
+
+
+class ResidualAffineCouplingBlock(torch.nn.Module):
+ """Residual affine coupling block module.
+
+ This is a module of residual affine coupling block, which used as "Flow" in
+ `Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`_.
+
+ .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`: https://arxiv.org/abs/2106.06103
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 192,
+ hidden_channels: int = 192,
+ flows: int = 4,
+ kernel_size: int = 5,
+ base_dilation: int = 1,
+ layers: int = 4,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ use_weight_norm: bool = True,
+ bias: bool = True,
+ use_only_mean: bool = True,
+ ):
+ """Initialize ResidualAffineCouplingBlock module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ hidden_channels (int): Number of hidden channels.
+ flows (int): Number of flows.
+ kernel_size (int): Kernel size for WaveNet.
+ base_dilation (int): Base dilation factor for WaveNet.
+ layers (int): Number of layers of WaveNet.
+ global_channels (int): Number of global channels.
+ dropout_rate (float): Dropout rate.
+ use_weight_norm (bool): Whether to use weight normalization in WaveNet.
+ bias (bool): Whether to use bias parameters in WaveNet.
+ use_only_mean (bool): Whether to estimate only mean.
+
+ """
+ super().__init__()
+
+ # Every coupling layer is followed by a FlipFlow; the number of WaveNet
+ # stacks is fixed to 1 here and is not configurable.
+ self.flows = torch.nn.ModuleList()
+ for i in range(flows):
+ self.flows += [
+ ResidualAffineCouplingLayer(
+ in_channels=in_channels,
+ hidden_channels=hidden_channels,
+ kernel_size=kernel_size,
+ base_dilation=base_dilation,
+ layers=layers,
+ stacks=1,
+ global_channels=global_channels,
+ dropout_rate=dropout_rate,
+ use_weight_norm=use_weight_norm,
+ bias=bias,
+ use_only_mean=use_only_mean,
+ )
+ ]
+ self.flows += [FlipFlow()]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+ x_mask (Tensor): Mask tensor, presumably (B, 1, T) -- TODO confirm.
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, in_channels, T).
+
+ """
+ if not inverse:
+ # In the forward direction each flow returns two values; the second
+ # one is discarded.
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, inverse=inverse)
+ else:
+ # The inverse direction applies the flows in reverse order.
+ for flow in reversed(self.flows):
+ x = flow(x, x_mask, g=g, inverse=inverse)
+ return x
+
+
+class ResidualAffineCouplingLayer(torch.nn.Module):
+ """Residual affine coupling layer.
+
+ The input channels are split into two halves: the first half conditions an
+ affine transform of the second half and is itself passed through unchanged.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 192,
+ hidden_channels: int = 192,
+ kernel_size: int = 5,
+ base_dilation: int = 1,
+ layers: int = 5,
+ stacks: int = 1,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ use_weight_norm: bool = True,
+ bias: bool = True,
+ use_only_mean: bool = True,
+ ):
+ """Initialize ResidualAffineCouplingLayer module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ hidden_channels (int): Number of hidden channels.
+ kernel_size (int): Kernel size for WaveNet.
+ base_dilation (int): Base dilation factor for WaveNet.
+ layers (int): Number of layers of WaveNet.
+ stacks (int): Number of stacks of WaveNet.
+ global_channels (int): Number of global channels.
+ dropout_rate (float): Dropout rate.
+ use_weight_norm (bool): Whether to use weight normalization in WaveNet.
+ bias (bool): Whether to use bias parameters in WaveNet.
+ use_only_mean (bool): Whether to estimate only mean.
+
+ """
+ assert in_channels % 2 == 0, "in_channels should be divisible by 2"
+ super().__init__()
+ self.half_channels = in_channels // 2
+ self.use_only_mean = use_only_mean
+
+ # define modules
+ self.input_conv = torch.nn.Conv1d(
+ self.half_channels,
+ hidden_channels,
+ 1,
+ )
+ self.encoder = WaveNet(
+ in_channels=-1,
+ out_channels=-1,
+ kernel_size=kernel_size,
+ layers=layers,
+ stacks=stacks,
+ base_dilation=base_dilation,
+ residual_channels=hidden_channels,
+ aux_channels=-1,
+ gate_channels=hidden_channels * 2,
+ skip_channels=hidden_channels,
+ global_channels=global_channels,
+ dropout_rate=dropout_rate,
+ bias=bias,
+ use_weight_norm=use_weight_norm,
+ use_first_conv=False,
+ use_last_conv=False,
+ scale_residual=False,
+ scale_skip_connect=True,
+ )
+ # The projection outputs the mean only, or the mean and the log-scale.
+ if use_only_mean:
+ self.proj = torch.nn.Conv1d(
+ hidden_channels,
+ self.half_channels,
+ 1,
+ )
+ else:
+ self.proj = torch.nn.Conv1d(
+ hidden_channels,
+ self.half_channels * 2,
+ 1,
+ )
+ # Zero-initialized projection: the predicted mean and log-scale start
+ # at zero.
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+ x_mask (Tensor): Mask tensor, presumably (B, 1, T) -- TODO confirm.
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, in_channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ # xa drives the transform and is returned unchanged; only xb is modified.
+ xa, xb = x.split(x.size(1) // 2, dim=1)
+ h = self.input_conv(xa) * x_mask
+ h = self.encoder(h, x_mask, g=g)
+ stats = self.proj(h) * x_mask
+ if not self.use_only_mean:
+ m, logs = stats.split(stats.size(1) // 2, dim=1)
+ else:
+ # Mean-only mode: the log-scale is fixed to zero.
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not inverse:
+ xb = m + xb * torch.exp(logs) * x_mask
+ x = torch.cat([xa, xb], 1)
+ # Sum of the log-scales over the channel and time axes.
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ # Inverse of the transform above: subtract the mean, undo the scale.
+ xb = (xb - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([xa, xb], 1)
+ return x
diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py
new file mode 100755
index 000000000..8acca7c02
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/test_onnx.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is used to test the exported onnx model by vits/export-onnx.py
+
+Use the onnx model to generate a wav:
+./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
+"""
+
+
+import argparse
+import logging
+import onnxruntime as ort
+import torch
+import torchaudio
+
+from tokenizer import Tokenizer
+
+
+def get_parser():
+ """Build the command-line parser for this script.
+
+ Options: ``--model-filename`` (required, path to the onnx model) and
+ ``--tokens`` (path to the vocabulary, default ``data/tokens.txt``).
+ """
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-filename",
+ type=str,
+ required=True,
+ help="Path to the onnx model.",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ """Thin wrapper that runs an exported onnx model with onnxruntime on CPU."""
+
+ def __init__(self, model_filename: str):
+ """Create the inference session.
+
+ Args:
+ model_filename:
+ Path to the onnx model file.
+ """
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 4
+
+ self.session_opts = session_opts
+
+ self.model = ort.InferenceSession(
+ model_filename,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ # Log the custom metadata stored in the model file.
+ logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
+
+ def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ tokens:
+ A 2-D tensor of shape (1, T)
+ tokens_lens:
+ A tensor holding the number of tokens; main() in this file passes
+ a 1-D int64 tensor of shape (1,).
+ Returns:
+ A tensor of shape (1, T')
+ """
+ # Fixed synthesis settings, each passed as a one-element float32 tensor.
+ noise_scale = torch.tensor([0.667], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
+ alpha = torch.tensor([1.0], dtype=torch.float32)
+
+ # Inputs are matched to the model by position (0..4), not by name; only
+ # the first model output is requested.
+ out = self.model.run(
+ [
+ self.model.get_outputs()[0].name,
+ ],
+ {
+ self.model.get_inputs()[0].name: tokens.numpy(),
+ self.model.get_inputs()[1].name: tokens_lens.numpy(),
+ self.model.get_inputs()[2].name: noise_scale.numpy(),
+ self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
+ self.model.get_inputs()[4].name: alpha.numpy(),
+ },
+ )[0]
+ return torch.from_numpy(out)
+
+
+def main():
+ """Synthesize one fixed sentence with the onnx model and save the audio.
+
+ The result is written to ``test_onnx.wav`` in the current directory with a
+ sample rate of 22050.
+ """
+ args = get_parser().parse_args()
+
+ tokenizer = Tokenizer(args.tokens)
+
+ logging.info("About to create onnx model")
+ model = OnnxModel(args.model_filename)
+
+ text = "I went there to see the land, the people and how their system works, end quote."
+ tokens = tokenizer.texts_to_token_ids([text])
+ tokens = torch.tensor(tokens) # (1, T)
+ tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1,)
+ audio = model(tokens, tokens_lens) # (1, T')
+
+ torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
+ logging.info("Saved to test_onnx.wav")
+
+
+# Script entry point: configure INFO-level logging with a timestamped format,
+# then run the test.
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py
new file mode 100644
index 000000000..9f337e45b
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/text_encoder.py
@@ -0,0 +1,662 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Text encoder module in VITS.
+
+This code is based on
+ - https://github.com/jaywalnut310/vits
+ - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
+ - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
+"""
+
+import copy
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, nn
+
+from icefall.utils import is_jit_tracing, make_pad_mask
+
+
+class TextEncoder(torch.nn.Module):
+    """Text encoder module in VITS.
+
+    Text encoder described in `Conditional Variational Autoencoder with
+    Adversarial Learning for End-to-End Text-to-Speech`: token ids are
+    embedded, passed through the encoder stack and projected by a 1x1
+    convolution whose output is split into two equal halves (`m`, `logs`).
+    """
+
+    def __init__(
+        self,
+        vocabs: int,
+        d_model: int = 192,
+        num_heads: int = 2,
+        dim_feedforward: int = 768,
+        cnn_module_kernel: int = 5,
+        num_layers: int = 6,
+        dropout: float = 0.1,
+    ):
+        """Initialize TextEncoder module.
+
+        Args:
+            vocabs (int): Vocabulary size.
+            d_model (int): attention dimension
+            num_heads (int): number of attention heads
+            dim_feedforward (int): feedforward dimension
+            cnn_module_kernel (int): convolution kernel size
+            num_layers (int): number of encoder layers
+            dropout (float): dropout rate
+        """
+        super().__init__()
+        self.d_model = d_model
+
+        # Token embedding, initialised from a normal with std d_model ** -0.5
+        self.emb = torch.nn.Embedding(vocabs, d_model)
+        torch.nn.init.normal_(self.emb.weight, mean=0.0, std=d_model**-0.5)
+
+        # We use conformer as text encoder
+        encoder_config = dict(
+            d_model=d_model,
+            num_heads=num_heads,
+            dim_feedforward=dim_feedforward,
+            cnn_module_kernel=cnn_module_kernel,
+            num_layers=num_layers,
+            dropout=dropout,
+        )
+        self.encoder = Transformer(**encoder_config)
+
+        # 1x1 convolution producing twice as many channels as d_model
+        self.proj = torch.nn.Conv1d(d_model, 2 * d_model, kernel_size=1)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input index tensor (B, T_text).
+            x_lengths (Tensor): Length tensor (B,).
+
+        Returns:
+            Tensor: Encoded hidden representation (B, attention_dim, T_text).
+            Tensor: Projected mean tensor (B, attention_dim, T_text).
+            Tensor: Projected scale tensor (B, attention_dim, T_text).
+            Tensor: Mask tensor for input tensor (B, 1, T_text).
+
+        """
+        # (B, T_text, embed_dim)
+        embedded = self.emb(x) * math.sqrt(self.d_model)
+
+        # The padded length must equal the longest sequence in the batch.
+        assert x_lengths.max().item() == embedded.size(1)
+
+        # (B, T_text)
+        pad_mask = make_pad_mask(x_lengths)
+
+        # The encoder works channel-last; switch to channel-first afterwards:
+        # (B, T_text, embed_dim) -> (B, embed_dim, T_text)
+        encoded = self.encoder(embedded, key_padding_mask=pad_mask).transpose(1, 2)
+
+        # (B, 1, T_text); multiplying by it zeroes the projection at padded positions
+        non_pad_mask = (~pad_mask).unsqueeze(1)
+        stats = self.proj(encoded) * non_pad_mask
+
+        half_channels = stats.size(1) // 2
+        m, logs = stats.split(half_channels, dim=1)
+
+        return encoded, m, logs, non_pad_mask
+
+
+class Transformer(nn.Module):
+    """Encoder stack: relative positional encoding, `num_layers` encoder
+    layers and a final LayerNorm.
+
+    Args:
+        d_model (int): attention dimension
+        num_heads (int): number of attention heads
+        dim_feedforward (int): feedforward dimension
+        cnn_module_kernel (int): convolution kernel size
+        num_layers (int): number of encoder layers
+        dropout (float): dropout rate
+    """
+
+    def __init__(
+        self,
+        d_model: int = 192,
+        num_heads: int = 2,
+        dim_feedforward: int = 768,
+        cnn_module_kernel: int = 5,
+        num_layers: int = 6,
+        dropout: float = 0.1,
+    ) -> None:
+        super().__init__()
+
+        self.num_layers = num_layers
+        self.d_model = d_model
+
+        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+        # One prototype layer; TransformerEncoder deep-copies it num_layers times.
+        encoder_layer = TransformerEncoderLayer(
+            d_model=d_model,
+            num_heads=num_heads,
+            dim_feedforward=dim_feedforward,
+            cnn_module_kernel=cnn_module_kernel,
+            dropout=dropout,
+        )
+        self.encoder = TransformerEncoder(encoder_layer, num_layers)
+        self.after_norm = nn.LayerNorm(d_model)
+
+    def forward(self, x: Tensor, key_padding_mask: Tensor) -> Tensor:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          key_padding_mask:
+            Mask forwarded unchanged to every encoder layer. The attention
+            module asserts its shape is (batch_size, seq_len).
+
+        Returns:
+          A single tensor of shape (batch_size, seq_len, feature_dim).
+        """
+        x, pos_emb = self.encoder_pos(x)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)
+
+        x = self.after_norm(x)
+
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        return x
+
+
+class TransformerEncoderLayer(nn.Module):
+ """
+ TransformerEncoderLayer is made up of a macaron-style feed-forward module,
+ relative-position self-attention, a convolution module and a second
+ feed-forward module, each wrapped in a pre-LayerNorm residual branch.
+
+ Args:
+ d_model: the number of expected features in the input.
+ num_heads: the number of heads in the multi-head attention models.
+ dim_feedforward: the dimension of the feed-forward network model.
+ cnn_module_kernel: the kernel size passed to ConvolutionModule.
+ dropout: the dropout value (default=0.1).
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ num_heads: int,
+ dim_feedforward: int,
+ cnn_module_kernel: int,
+ dropout: float = 0.1,
+ ) -> None:
+ super(TransformerEncoderLayer, self).__init__()
+
+ # First (macaron) feed-forward branch: Linear -> Swish -> Dropout -> Linear
+ self.feed_forward_macaron = nn.Sequential(
+ nn.Linear(d_model, dim_feedforward),
+ Swish(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_feedforward, d_model),
+ )
+
+ self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
+
+ self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
+ # Second feed-forward branch, same layout as the macaron one
+ self.feed_forward = nn.Sequential(
+ nn.Linear(d_model, dim_feedforward),
+ Swish(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_feedforward, d_model),
+ )
+
+ self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
+ self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
+ self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
+ self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
+ self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
+
+ # Scale applied to the macaron feed-forward branch only (see forward)
+ self.ff_scale = 0.5
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the transformer encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
+ pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
+ key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
+
+ Returns:
+ The layer output after the final LayerNorm; every branch is added
+ residually to `src`, so the shape is unchanged.
+ """
+ # macaron style feed-forward module
+ src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+
+ # multi-head self-attention module
+ src_attn = self.self_attn(
+ self.norm_mha(src),
+ pos_emb=pos_emb,
+ key_padding_mask=key_padding_mask,
+ )
+ src = src + self.dropout(src_attn)
+
+ # convolution module
+ # NOTE(review): key_padding_mask is not forwarded here although
+ # ConvolutionModule.forward accepts src_key_padding_mask - confirm this is intended.
+ src = src + self.dropout(self.conv_module(self.norm_conv(src)))
+
+ # feed-forward module (no ff_scale on this branch)
+ src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
+
+ src = self.norm_final(src)
+
+ return src
+
+
+class TransformerEncoder(nn.Module):
+    r"""A stack of independent copies of one encoder layer, applied in order.
+
+    Args:
+        encoder_layer: an instance of the TransformerEncoderLayer class; it is
+            deep-copied once per layer, so the copies do not share parameters.
+        num_layers: the number of sub-encoder-layers in the encoder.
+    """
+
+    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+        super().__init__()
+
+        clones = [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
+        self.layers = nn.ModuleList(clones)
+        self.num_layers = num_layers
+
+    def forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""Feed `src` through every layer, each consuming the previous output.
+
+        Args:
+            src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
+            pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
+            key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
+
+        Returns:
+            The output of the last layer.
+        """
+        hidden = src
+        for layer in self.layers:
+            hidden = layer(hidden, pos_emb, key_padding_mask=key_padding_mask)
+        return hidden
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+        """Construct an PositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        # Pre-compute the table for `max_len` positions; only size(1) of the
+        # dummy tensor is used.
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: Tensor) -> None:
+        """Reset the positional encodings.
+
+        Rebuilds `self.pe` when it is missing or too short for `x.size(1)`
+        positions; otherwise only moves it to the dtype/device of `x`.
+        """
+        x_size = x.size(1)
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x_size * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means to the position of query vector and `j` means the
+        # position of key vector. We use position relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x_size, self.d_model)
+        pe_negative = torch.zeros(x_size, self.d_model)
+        position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices, so that relative position 0 sits in the middle of
+        # the table. This is used to support the shifting trick
+        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+        """
+        self.extend_pe(x)
+        x = x * self.xscale
+        # Take the 2*time-1 entries centred on relative position 0.
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2  # noqa E203
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+ r"""Multi-Head Attention layer with relative position encoding
+
+ See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+ Args:
+ embed_dim: total dimension of the model.
+ num_heads: parallel attention heads.
+ dropout: dropout probability applied to attn_output_weights. Default: 0.0.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ ) -> None:
+ super(RelPositionMultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+
+ # A single projection produces query, key and value (split in forward)
+ self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+ # linear transformation for positional encoding.
+ self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ # (allocated uninitialised here; filled in by _reset_parameters below)
+ self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ """Initialise in_proj (Xavier weight, zero bias), zero the out_proj bias
+ and Xavier-initialise both positional biases. out_proj.weight and
+ linear_pos are not touched here."""
+ nn.init.xavier_uniform_(self.in_proj.weight)
+ nn.init.constant_(self.in_proj.bias, 0.0)
+ nn.init.constant_(self.out_proj.bias, 0.0)
+
+ nn.init.xavier_uniform_(self.pos_bias_u)
+ nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x: Tensor) -> Tensor:
+ """Compute relative positional encoding.
+
+ Both branches below select, for query row i and key column j, the
+ element x[..., i, (seq_len - 1 - i) + j].
+
+ Args:
+ x: Input tensor (batch, head, seq_len, 2*seq_len-1).
+
+ Returns:
+ Tensor: tensor of shape (batch, head, seq_len, seq_len)
+ """
+ (batch_size, num_heads, seq_len, n) = x.shape
+
+ if not is_jit_tracing():
+ assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
+
+ if is_jit_tracing():
+ # Tracing path: build the column indexes explicitly and gather them.
+ rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
+ cols = torch.arange(seq_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+
+ x = x.reshape(-1, n)
+ x = torch.gather(x, dim=1, index=indexes)
+ x = x.reshape(batch_size, num_heads, seq_len, seq_len)
+ return x
+ else:
+ # Non-tracing path: the same selection expressed as a strided view.
+ # Starting at column seq_len-1 and using a row stride that is one
+ # element shorter moves the window one column to the left per row.
+ # Note: TorchScript requires explicit arg for stride()
+ batch_stride = x.stride(0)
+ head_stride = x.stride(1)
+ time_stride = x.stride(2)
+ n_stride = x.stride(3)
+ return x.as_strided(
+ (batch_size, num_heads, seq_len, seq_len),
+ (batch_stride, head_stride, time_stride - n_stride, n_stride),
+ storage_offset=n_stride * (seq_len - 1),
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Args:
+ x: Input tensor of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. This is an binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ Its shape is (batch_size, seq_len).
+
+ Outputs:
+ A tensor of shape (seq_len, batch_size, embed_dim).
+ """
+ seq_len, batch_size, _ = x.shape
+ scaling = float(self.head_dim) ** -0.5
+
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
+
+ q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
+ k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
+ # (batch_size * num_head, seq_len, head_dim), the layout torch.bmm needs below
+ v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+
+ q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
+
+ p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
+ # (1, 2*seq_len-1, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
+ p = p.permute(0, 2, 3, 1)
+
+ # (batch_size, num_head, seq_len, head_dim)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
+ matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len)
+
+ # compute matrix b and matrix d
+ matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1)
+ matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len)
+
+ # (batch_size, num_head, seq_len, seq_len)
+ attn_output_weights = (matrix_ac + matrix_bd) * scaling
+ attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (batch_size, seq_len)
+ # Un-flatten the head dimension so the (batch, 1, 1, seq_len) mask
+ # broadcasts over heads and query positions, then flatten again.
+ attn_output_weights = attn_output_weights.view(
+ batch_size, self.num_heads, seq_len, seq_len
+ )
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float("-inf"),
+ )
+ attn_output_weights = attn_output_weights.view(
+ batch_size * self.num_heads, seq_len, seq_len
+ )
+
+ attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+ attn_output_weights = nn.functional.dropout(
+ attn_output_weights, p=self.dropout, training=self.training
+ )
+
+ # (batch_size * num_head, seq_len, head_dim)
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
+
+ # Merge the heads back into the embedding dimension
+ attn_output = (
+ attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
+ )
+ # (seq_len, batch_size, embed_dim)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model.
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+    Pipeline: pointwise conv -> GLU -> depthwise conv -> LayerNorm -> Swish
+    -> pointwise conv.
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of the depthwise conv layer; must be odd.
+        bias (bool): Whether to use bias in conv layers (default=True).
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        bias: bool = True,
+    ) -> None:
+        """Construct an ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        # Doubles the channels so that GLU can halve them again.
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+
+        # groups=channels makes this a depthwise convolution; the padding
+        # keeps the time dimension unchanged.
+        padding = (kernel_size - 1) // 2
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+            bias=bias,
+        )
+        self.norm = nn.LayerNorm(channels)
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = Swish()
+
+    def forward(
+        self,
+        x: Tensor,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+                When given, masked positions are zeroed before the depthwise
+                convolution. It is expanded from (batch, 1, time), so it is
+                presumably of shape (batch, time) - confirm with the caller.
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+        # 1D Depthwise Conv
+        if src_key_padding_mask is not None:
+            x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+        x = self.depthwise_conv(x)
+        # x is (batch, channels, time); LayerNorm normalises the last
+        # dimension, so channels are moved there and back.
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)
+
+        x = self.activation(x)
+
+        x = self.pointwise_conv2(x)  # (batch, channel, time)
+
+        return x.permute(2, 0, 1)
+
+
+class Swish(nn.Module):
+    """Swish activation: `x * sigmoid(x)`, applied element-wise."""
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Return the input multiplied by its own sigmoid."""
+        gate = torch.sigmoid(x)
+        return x * gate
+
+
+def _test_text_encoder():
+    """Smoke test: run a TextEncoder on random full-length sequences and print
+    the shapes of its four outputs."""
+    vocabs, d_model = 500, 192
+    batch_size, seq_len = 5, 100
+
+    encoder = TextEncoder(vocabs=vocabs, d_model=d_model)
+
+    token_ids = torch.randint(low=0, high=vocabs, size=(batch_size, seq_len))
+    # Every sequence uses the full length, so nothing is padded.
+    lengths = torch.full((batch_size,), seq_len)
+
+    x, m, logs, mask = encoder(x=token_ids, x_lengths=lengths)
+    print(x.shape, m.shape, logs.shape, mask.shape)
+
+
+if __name__ == "__main__":
+ # Running this file directly executes the shape smoke test above.
+ _test_text_encoder()
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
new file mode 100644
index 000000000..0678b26fe
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -0,0 +1,106 @@
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from utils import intersperse
+
+
+class Tokenizer(object):
+    """Converts English text or ready-made token sequences to token ids using
+    a token table file."""
+
+    def __init__(self, tokens: str):
+        """
+        Args:
+          tokens: the file that maps tokens to ids
+        """
+        # Parse token file
+        self.token2id: Dict[str, int] = {}
+        with open(tokens, "r", encoding="utf-8") as f:
+            for line in f:
+                info = line.rstrip().split()
+                if not info:
+                    # Skip blank lines instead of failing with IndexError.
+                    continue
+                if len(info) == 1:
+                    # case of space: the token itself is swallowed by split()
+                    token, token_id = " ", int(info[0])
+                else:
+                    token, token_id = info[0], int(info[1])
+                self.token2id[token] = token_id
+
+        # Special symbols. The parser above never produces an empty-string
+        # key, so these lookups need the real symbol names.
+        # NOTE(review): "<blk>" / "<unk>" are the names assumed for the blank
+        # and out-of-vocabulary entries - confirm against the token file.
+        self.blank_id = self.token2id["<blk>"]
+        self.oov_id = self.token2id["<unk>"]
+        self.vocab_size = len(self.token2id)
+
+        self.g2p = g2p_en.G2p()
+
+    def _tokens_to_ids(self, tokens: List[str], intersperse_blank: bool) -> List[int]:
+        """Map one utterance's tokens to ids, using `oov_id` for unknown
+        tokens and optionally interspersing `blank_id`."""
+        token_ids = [self.token2id.get(t, self.oov_id) for t in tokens]
+        if intersperse_blank:
+            token_ids = intersperse(token_ids, self.blank_id)
+        return token_ids
+
+    def texts_to_token_ids(
+        self, texts: List[str], intersperse_blank: bool = True
+    ) -> List[List[int]]:
+        """
+        Args:
+          texts:
+            A list of transcripts.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        token_ids_list = []
+
+        for text in texts:
+            # Text normalization
+            text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+            # Convert to phonemes
+            tokens = self.g2p(text)
+            token_ids_list.append(self._tokens_to_ids(tokens, intersperse_blank))
+
+        return token_ids_list
+
+    def tokens_to_token_ids(
+        self, tokens_list: List[List[str]], intersperse_blank: bool = True
+    ) -> List[List[int]]:
+        """
+        Args:
+          tokens_list:
+            A list of token list, each corresponding to one utterance.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+
+        Returns:
+          Return a list of token id list [utterance][token_id]
+        """
+        return [
+            self._tokens_to_ids(tokens, intersperse_blank) for tokens in tokens_list
+        ]
diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
new file mode 100755
index 000000000..eb43a4cc9
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -0,0 +1,893 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import numpy as np
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from torch.optim import Optimizer
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+from tokenizer import Tokenizer
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
+
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def get_parser():
+ """Build the command-line parser for this training script.
+
+ Help output shows default values because ArgumentDefaultsHelpFormatter
+ is used.
+ """
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=1000,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ parser.add_argument(
+ "--lr", type=float, default=2.0e-4, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ # NOTE(review): the help text below contains a stray '"' right after
+ # "epochs"; it is shown verbatim to the user.
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=20,
+ help="""Save checkpoint after processing this number of epochs"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.cur_epoch % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
+ Since it will take around 1000 epochs, we suggest using a large
+ save_every_n to save disk space.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if `batch_idx % log_interval` is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - env_info: The value returned by `get_env_info()` when this function
+ is called.
+
+ - sampling_rate: Audio sampling rate handed to the model.
+
+ - frame_shift, frame_length: Framing settings forwarded to the Mel loss
+ (presumably in samples -- confirm against the
+ feature extraction).
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - n_mels: Number of Mel bins forwarded to the Mel loss.
+
+ - lambda_adv, lambda_mel, lambda_feat_match, lambda_dur, lambda_kl:
+ Loss scaling coefficients; see the inline comments below.
+ """
+ params = AttributeDict(
+ {
+ # training params
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": -1, # incremented before use, so the first batch is 0
+ "log_interval": 50,
+ "valid_interval": 200,
+ "env_info": get_env_info(),
+ "sampling_rate": 22050,
+ "frame_shift": 256,
+ "frame_length": 1024,
+ "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
+ "n_mels": 80,
+ "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
+ "lambda_mel": 45.0, # loss scaling coefficient for Mel loss
+ "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
+ "lambda_dur": 1.0, # loss scaling coefficient for duration loss
+ "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict, model: nn.Module
+) -> Optional[Dict[str, Any]]:
+    """Resume from `exp_dir/epoch-{start_epoch-1}.pt` when `start_epoch > 1`.
+
+    Loads the model state through `load_checkpoint` and copies the
+    bookkeeping entries `best_train_epoch`, `best_valid_epoch`,
+    `batch_idx_train`, `best_train_loss` and `best_valid_loss` from the
+    checkpoint into `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+    Returns:
+      The dict returned by `load_checkpoint`, or None when `start_epoch` is
+      not larger than 1 and nothing is loaded.
+    """
+    if params.start_epoch <= 1:
+        # Fresh run: nothing to resume from.
+        return None
+
+    checkpoint_path = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    assert checkpoint_path.is_file(), f"{checkpoint_path} does not exist!"
+
+    saved_params = load_checkpoint(checkpoint_path, model=model)
+
+    restored_keys = (
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    )
+    for key in restored_keys:
+        params[key] = saved_params[key]
+
+    return saved_params
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+    """Build the VITS model from the settings stored in `params`."""
+    # Settings consumed by the Mel loss inside the model.
+    mel_loss_params = dict(
+        n_mels=params.n_mels,
+        frame_length=params.frame_length,
+        frame_shift=params.frame_shift,
+    )
+    return VITS(
+        vocab_size=params.vocab_size,
+        feature_dim=params.feature_dim,
+        sampling_rate=params.sampling_rate,
+        mel_loss_params=mel_loss_params,
+        lambda_adv=params.lambda_adv,
+        lambda_mel=params.lambda_mel,
+        lambda_feat_match=params.lambda_feat_match,
+        lambda_dur=params.lambda_dur,
+        lambda_kl=params.lambda_kl,
+    )
+
+
+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+    """Move a batch to `device` and turn its token lists into a padded id tensor.
+
+    Returns:
+      The tuple (audio, audio_lens, features, features_lens, tokens, tokens_lens).
+    """
+    audio = batch["audio"].to(device)
+    features = batch["features"].to(device)
+    audio_lens = batch["audio_lens"].to(device)
+    features_lens = batch["features_lens"].to(device)
+
+    # Token lists -> id lists -> ragged tensor with one row per utterance
+    token_ids = tokenizer.tokens_to_token_ids(batch["tokens"])
+    ragged = k2.RaggedTensor(token_ids)
+
+    # Consecutive row-split differences give the length of each row.
+    splits = ragged.shape.row_splits(1)
+    tokens_lens = (splits[1:] - splits[:-1]).to(device)
+
+    # a tensor of shape (B, T), padded with the blank id
+    tokens = ragged.to(device).pad(mode="constant", padding_value=tokenizer.blank_id)
+
+    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ optimizer_g: Optimizer,
+ optimizer_d: Optimizer,
+ scheduler_g: LRSchedulerType,
+ scheduler_d: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ For every batch the discriminator is updated first and the generator
+ second. The generator loss averaged over all samples seen in the epoch
+ is saved in `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ tokenizer:
+ Used to convert text to phonemes.
+ optimizer_g:
+ The optimizer for generator.
+ optimizer_d:
+ The optimizer for discriminator.
+ scheduler_g:
+ The learning rate scheduler for generator, we call step() every epoch.
+ scheduler_d:
+ The learning rate scheduler for discriminator, we call step() every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summarize the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ # Dumps the full training state for debugging. The file name carries this
+ # process's rank, while save_checkpoint itself is always called with rank=0.
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ params=params,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["tokens"])
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+ prepare_input(batch, tokenizer, device)
+
+ # Per-batch stats; every value is weighted by the batch size so that
+ # totals can later be divided by the accumulated 'samples' count.
+ loss_info = MetricsTracker()
+ loss_info['samples'] = batch_size
+
+ try:
+ with autocast(enabled=params.use_fp16):
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=False,
+ )
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+ # update discriminator
+ optimizer_d.zero_grad()
+ scaler.scale(loss_d).backward()
+ scaler.step(optimizer_d)
+
+ with autocast(enabled=params.use_fp16):
+ # forward generator
+ # (a sample is requested from the model only on logging steps)
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=True,
+ return_sample=params.batch_idx_train % params.log_interval == 0,
+ )
+ for k, v in stats_g.items():
+ # the returned sample is not a scalar stat; it is logged separately below
+ if "returned_sample" not in k:
+ loss_info[k] = v * batch_size
+ # update generator
+ optimizer_g.zero_grad()
+ scaler.scale(loss_g).backward()
+ scaler.step(optimizer_g)
+ # the shared scaler is updated once, after both optimizers have stepped
+ scaler.update()
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+ except: # noqa
+ # save the state for post-mortem inspection, then propagate the error
+ save_bad_model()
+ raise
+
+ # In diagnostics mode stop after batch_idx 0..5; note that this returns
+ # before `params.train_loss` is computed at the end of this function.
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr_g = max(scheduler_g.get_last_lr())
+ cur_lr_d = max(scheduler_d.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate_g", cur_lr_g, params.batch_idx_train
+ )
+ tb_writer.add_scalar(
+ "train/learning_rate_d", cur_lr_d, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+ # log the generated/reference audio and mel images when a sample was returned
+ if "returned_sample" in stats_g:
+ speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+ tb_writer.add_audio(
+ "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+ )
+ tb_writer.add_audio(
+ "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+ )
+ tb_writer.add_image(
+ "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+ )
+ tb_writer.add_image(
+ "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+ )
+
+ if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ # NOTE(review): `rank` is not forwarded here, so compute_validation_loss
+ # runs with its default rank=0 in every process - confirm this is intended.
+ valid_info, (speech_hat, speech) = compute_validation_loss(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ # compute_validation_loss switches the model to eval mode; switch back
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+ # NOTE(review): "valdi" in the two tags below looks like a typo for
+ # "valid"; left unchanged because renaming changes the TensorBoard tags.
+ tb_writer.add_audio(
+ "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+ )
+ tb_writer.add_audio(
+ "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
+ )
+
+ # epoch-level training loss: generator loss averaged over all samples
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+ rank: int = 0,
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
+ """Run the validation process.
+
+ Puts the model in eval mode (the caller is responsible for switching it
+ back), accumulates discriminator and generator stats over `valid_dl`
+ under `torch.no_grad()`, and updates `params.best_valid_loss` /
+ `params.best_valid_epoch` when the averaged generator loss improves.
+
+ Returns:
+ A tuple (tot_loss, returned_sample). `returned_sample` is
+ (audio_pred, audio_gt) built from the first utterance of the first
+ batch when `rank == 0`; it stays None if `rank != 0` or if `valid_dl`
+ yields no batch.
+ """
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summarize the stats over iterations
+ tot_loss = MetricsTracker()
+ returned_sample = None
+
+ with torch.no_grad():
+ for batch_idx, batch in enumerate(valid_dl):
+ batch_size = len(batch["tokens"])
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+ prepare_input(batch, tokenizer, device)
+
+ # stats are weighted by batch size, as in training
+ loss_info = MetricsTracker()
+ loss_info['samples'] = batch_size
+
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=False,
+ )
+ assert loss_d.requires_grad is False
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+
+ # forward generator
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=True,
+ )
+ assert loss_g.requires_grad is False
+ for k, v in stats_g.items():
+ loss_info[k] = v * batch_size
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+
+ # infer for first batch:
+ if batch_idx == 0 and rank == 0:
+ # inference() lives on the wrapped module when DDP is used
+ inner_model = model.module if isinstance(model, DDP) else model
+ # synthesize from the unpadded tokens of the first utterance
+ audio_pred, _, duration = inner_model.inference(
+ text=tokens[0, :tokens_lens[0].item()]
+ )
+ audio_pred = audio_pred.data.cpu().numpy()
+ # sanity check: predicted durations times frame_shift must match the
+ # number of generated samples
+ audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+ assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
+ audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
+ returned_sample = (audio_pred, audio_gt)
+
+ if world_size > 1:
+ tot_loss.reduce(device)
+
+ # validation loss: generator loss averaged over all validation samples
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss, returned_sample
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ tokenizer: Tokenizer,
+ optimizer_g: torch.optim.Optimizer,
+ optimizer_d: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ """Run forward/backward on the batches picked by lhotse's
+ `find_pessimistic_batches` to detect out-of-memory problems early.
+
+ Both the discriminator and the generator passes are executed for each
+ selected batch. Gradients are computed, but no optimizer step is taken
+ and no grad scaler is used. Any exception is re-raised; for a
+ "CUDA out of memory" error a hint about `max_duration` is logged first.
+ """
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ # materialize the batch directly from the dataset for these cuts
+ batch = train_dl.dataset[cuts]
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+ prepare_input(batch, tokenizer, device)
+ try:
+ # for discriminator
+ with autocast(enabled=params.use_fp16):
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=False,
+ )
+ optimizer_d.zero_grad()
+ loss_d.backward()
+ # for generator
+ with autocast(enabled=params.use_fp16):
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=True,
+ )
+ optimizer_g.zero_grad()
+ loss_g.backward()
+ except Exception as e:
+ # only the OOM case gets an extra hint; every exception is re-raised
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def run(rank, world_size, args):
+ """
+ Set up and run the whole training: build the model, optimizers,
+ schedulers and dataloaders, optionally resume from a checkpoint, then
+ train epoch by epoch and save checkpoints.
+
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ # defaults from get_params(), overridden by the command-line arguments
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ # only rank 0 writes tensorboard logs
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+ # keep direct references to the sub-modules; they are taken before the
+ # optional DDP wrapping and are used to build the two optimizers below
+ generator = model.generator
+ discriminator = model.discriminator
+
+ num_param_g = sum([p.numel() for p in generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ # separate optimizers for the generator and the discriminator
+ optimizer_g = torch.optim.AdamW(
+ generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+ optimizer_d = torch.optim.AdamW(
+ discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+
+ # exponential decay; the schedulers are stepped once per epoch below
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
+
+ if checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer_g" in checkpoints:
+ logging.info("Loading optimizer_g state dict")
+ optimizer_g.load_state_dict(checkpoints["optimizer_g"])
+ if "optimizer_d" in checkpoints:
+ logging.info("Loading optimizer_d state dict")
+ optimizer_d.load_state_dict(checkpoints["optimizer_d"])
+
+ # load state_dict for schedulers
+ if "scheduler_g" in checkpoints:
+ logging.info("Loading scheduler_g state dict")
+ scheduler_g.load_state_dict(checkpoints["scheduler_g"])
+ if "scheduler_d" in checkpoints:
+ logging.info("Loading scheduler_d state dict")
+ scheduler_d.load_state_dict(checkpoints["scheduler_d"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ ljspeech = LJSpeechTtsDataModule(args)
+
+ train_cuts = ljspeech.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+ train_dl = ljspeech.train_dataloaders(train_cuts)
+
+ valid_cuts = ljspeech.valid_cuts()
+ valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+
+ # the OOM sanity check is skipped in diagnostics mode
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ params=params,
+ )
+
+ # a single grad scaler is shared by both optimizers in train_one_epoch
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+
+ # re-seed per epoch; epochs are 1-based, the sampler epoch is 0-based
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ # diagnostics mode: print the collected diagnostics and stop
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ if rank == 0:
+ # keep copies of the checkpoints with the best train/valid loss
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ # step per epoch
+ scheduler_g.step()
+ scheduler_d.step()
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ # Parse the command line (training options plus the data-module options)
+ # and start training: one spawned process per rank when world_size > 1,
+ # otherwise a single in-process run with rank 0.
+ parser = get_parser()
+ LJSpeechTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ # mp.spawn passes the process index as the first argument (rank) of run()
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+# Restrict PyTorch to one intra-op and one inter-op CPU thread. These calls
+# run at import time, i.e. also when this file is imported, not only when it
+# is executed as a script.
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/ljspeech/TTS/vits/transform.py b/egs/ljspeech/TTS/vits/transform.py
new file mode 100644
index 000000000..c20d13130
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/transform.py
@@ -0,0 +1,218 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
+
+"""Flow-related transformation.
+
+This code is derived from https://github.com/bayesiains/nflows.
+
+"""
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+DEFAULT_MIN_BIN_WIDTH = 1e-3
+DEFAULT_MIN_BIN_HEIGHT = 1e-3
+DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+# TODO(kan-bayashi): Documentation and type hint
+def piecewise_rational_quadratic_transform(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails=None,
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ """Apply a piecewise rational-quadratic spline transform.
+
+ Dispatches to `rational_quadratic_spline` when `tails` is None (in which
+ case `tail_bound` is not used) and to
+ `unconstrained_rational_quadratic_spline` otherwise, forwarding `tails`
+ and `tail_bound`.
+
+ Returns:
+ A tuple (outputs, logabsdet) as produced by the selected spline function.
+ """
+ if tails is None:
+ spline_fn = rational_quadratic_spline
+ spline_kwargs = {}
+ else:
+ spline_fn = unconstrained_rational_quadratic_spline
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
+
+ outputs, logabsdet = spline_fn(
+ inputs=inputs,
+ unnormalized_widths=unnormalized_widths,
+ unnormalized_heights=unnormalized_heights,
+ unnormalized_derivatives=unnormalized_derivatives,
+ inverse=inverse,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ **spline_kwargs
+ )
+ return outputs, logabsdet
+
+
+# TODO(kan-bayashi): Documentation and type hint
+def unconstrained_rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails="linear",
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ """Rational-quadratic spline on [-tail_bound, tail_bound] with identity tails.
+
+ Elements of `inputs` outside [-tail_bound, tail_bound] are passed through
+ unchanged with a log-abs-determinant of 0. Elements inside the interval
+ are transformed by `rational_quadratic_spline` whose domain and range are
+ both set to [-tail_bound, tail_bound].
+
+ Raises:
+ RuntimeError: if `tails` is anything other than "linear".
+
+ Returns:
+ A tuple (outputs, logabsdet), each created with `torch.zeros_like(inputs)`.
+ """
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
+ outside_interval_mask = ~inside_interval_mask
+
+ outputs = torch.zeros_like(inputs)
+ logabsdet = torch.zeros_like(inputs)
+
+ if tails == "linear":
+ # add one boundary derivative on each side of the last dimension
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
+ # inverse softplus of (1 - min_derivative): rational_quadratic_spline
+ # computes min_derivative + softplus(.), so the boundary derivatives
+ # become 1, matching the slope of the identity tails
+ constant = np.log(np.exp(1 - min_derivative) - 1)
+ unnormalized_derivatives[..., 0] = constant
+ unnormalized_derivatives[..., -1] = constant
+
+ # identity mapping outside the interval
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
+ logabsdet[outside_interval_mask] = 0
+ else:
+ raise RuntimeError("{} tails are not implemented.".format(tails))
+
+ # transform only the elements inside the interval; boolean-mask indexing
+ # selects those elements together with their spline parameters
+ (
+ outputs[inside_interval_mask],
+ logabsdet[inside_interval_mask],
+ ) = rational_quadratic_spline(
+ inputs=inputs[inside_interval_mask],
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+ inverse=inverse,
+ left=-tail_bound,
+ right=tail_bound,
+ bottom=-tail_bound,
+ top=tail_bound,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ )
+
+ return outputs, logabsdet
+
+
+# TODO(kan-bayashi): Documentation and type hint
+def rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ left=0.0,
+ right=1.0,
+ bottom=0.0,
+ top=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ """Element-wise rational-quadratic spline mapping [left, right] to [bottom, top].
+
+ Bin widths and heights come from a softmax over the last dimension of
+ `unnormalized_widths` / `unnormalized_heights`, floored at `min_bin_width`
+ / `min_bin_height`; knot derivatives are
+ `min_derivative + softplus(unnormalized_derivatives)`.
+
+ Raises:
+ ValueError: if any input lies outside [left, right] (this check uses
+ left/right for the inverse direction as well), or if the minimal bin
+ width/height times the number of bins exceeds 1.
+
+ Returns:
+ A tuple (outputs, logabsdet). With `inverse=True` the returned
+ log-abs-determinant is negated.
+ """
+ if torch.min(inputs) < left or torch.max(inputs) > right:
+ raise ValueError("Input to a transform is not within its domain")
+
+ num_bins = unnormalized_widths.shape[-1]
+
+ if min_bin_width * num_bins > 1.0:
+ raise ValueError("Minimal bin width too large for the number of bins")
+ if min_bin_height * num_bins > 1.0:
+ raise ValueError("Minimal bin height too large for the number of bins")
+
+ # normalized bin widths with a floor of min_bin_width, then knot x-positions
+ widths = F.softmax(unnormalized_widths, dim=-1)
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
+ cumwidths = torch.cumsum(widths, dim=-1)
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
+ cumwidths = (right - left) * cumwidths + left
+ # pin the end points exactly to the interval boundaries
+ cumwidths[..., 0] = left
+ cumwidths[..., -1] = right
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
+
+ # derivatives at the knots, bounded below by min_derivative
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
+
+ # same construction for the bin heights and knot y-positions
+ heights = F.softmax(unnormalized_heights, dim=-1)
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
+ cumheights = torch.cumsum(heights, dim=-1)
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
+ cumheights = (top - bottom) * cumheights + bottom
+ cumheights[..., 0] = bottom
+ cumheights[..., -1] = top
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
+
+ # locate the bin of every input: by height for the inverse direction,
+ # by width for the forward direction
+ if inverse:
+ bin_idx = _searchsorted(cumheights, inputs)[..., None]
+ else:
+ bin_idx = _searchsorted(cumwidths, inputs)[..., None]
+
+ # gather the parameters of the selected bin for every input element
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
+
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
+ # slope of the straight line across each bin
+ delta = heights / widths
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
+
+ # derivatives at the left and right knot of the selected bin
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
+
+ if inverse:
+ # coefficients of the quadratic a*x^2 + b*x + c = 0 whose root is the
+ # relative position inside the bin
+ a = (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ ) + input_heights * (input_delta - input_derivatives)
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ )
+ c = -input_delta * (inputs - input_cumheights)
+
+ discriminant = b.pow(2) - 4 * a * c
+ assert (discriminant >= 0).all()
+
+ # root written as 2c / (-b - sqrt(disc)) rather than the textbook form
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
+ outputs = root * input_bin_widths + input_cumwidths
+
+ theta_one_minus_theta = root * (1 - root)
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * root.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - root).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ # the log-determinant of the inverse is the negated forward one
+ return outputs, -logabsdet
+ else:
+ # relative position of the input inside its bin
+ theta = (inputs - input_cumwidths) / input_bin_widths
+ theta_one_minus_theta = theta * (1 - theta)
+
+ numerator = input_heights * (
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
+ )
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ outputs = input_cumheights + numerator / denominator
+
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * theta.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - theta).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ return outputs, logabsdet
+
+
+def _searchsorted(bin_locations, inputs, eps=1e-6):
+ # Return, for every input, the index of the bin it falls into: the number
+ # of bin edges that are <= the input, minus one.
+ # NOTE: `eps` is added to the last edge IN PLACE, so the tensor passed by
+ # the caller is modified; the shift makes an input equal to the original
+ # last edge fall into the last bin instead of one past it.
+ bin_locations[..., -1] += eps
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py
new file mode 100644
index 000000000..0fcbb92c1
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/tts_datamodule.py
@@ -0,0 +1,325 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ SpeechSynthesisDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ # Callable used as the DataLoader `worker_init_fn`: seeds each worker with
+ # the base seed plus its worker id.
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class LJSpeechTtsDataModule:
+ """
+ DataModule for the LJSpeech TTS experiments.
+ It builds one train, one valid and one test dataloader from Lhotse
+ CutSets and loads the corresponding cut manifests from
+ `args.manifest_dir`.
+
+ It contains the data pipeline options used by the recipe, e.g.:
+ - dynamic batch size (`--max-duration`),
+ - bucketing samplers,
+ - on-the-fly feature extraction
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ # Registers the data-related command-line options on `parser`.
+ group = parser.add_argument_group(
+ title="TTS data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/spectrogram"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ # NOTE(review): type=int but the default is the float 200.0 - confirm
+ # which type is intended.
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, each batch will have the "
+ "field: batch['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Build the training dataloader.
+
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+ # NOTE(review): eval() instantiates whatever expression was given via
+ # --input-strategy; only trusted command lines should reach this.
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+
+ # with on-the-fly features the dataset created above is replaced
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=1024 / sampling_rate, # (in second),
+ frame_shift=256 / sampling_rate, # (in second)
+ use_fft_mag=True,
+ )
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ # batch_size=None: batching is left to the sampler/dataset rather than
+ # to the DataLoader
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ # Build the validation dataloader (no shuffling; the worker count is
+ # fixed to 2 here instead of using --num-workers).
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=1024 / sampling_rate, # (in second),
+ frame_shift=256 / sampling_rate, # (in second)
+ use_fft_mag=True,
+ )
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create valid dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ # Build the test dataloader (no shuffling, --num-workers workers).
+ logging.info("About to create test dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=1024 / sampling_rate, # (in second),
+ frame_shift=256 / sampling_rate, # (in second)
+ use_fft_mag=True,
+ )
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ test_sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=test_sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ # The three manifest loaders below are memoized with lru_cache.
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get validation cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
+ )
diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py
new file mode 100644
index 000000000..2a3dae900
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/utils.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Optional, Tuple, Union
+import collections
+import logging
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from lhotse.dataset.sampling.base import CutSampler
+from pathlib import Path
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_random_segments(
+    x: torch.Tensor,
+    x_lengths: torch.Tensor,
+    segment_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Cut one randomly positioned window out of every batch item.
+
+    Args:
+        x (Tensor): Input tensor (B, C, T).
+        x_lengths (Tensor): Length tensor (B,).
+        segment_size (int): Segment size.
+
+    Returns:
+        Tensor: Segmented tensor (B, C, segment_size).
+        Tensor: Start index tensor (B,).
+
+    """
+    batch_size, _, _ = x.size()
+    # Latest admissible start per item; items shorter than the window
+    # can only start at index 0.
+    latest_start = torch.clamp(x_lengths - segment_size, min=0)
+    # Uniform samples in [0, 1) scaled to [0, latest_start) and truncated.
+    uniform = torch.rand([batch_size]).to(x.device)
+    start_idxs = (uniform * latest_start).to(dtype=torch.long)
+    return get_segments(x, start_idxs, segment_size), start_idxs
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_segments(
+    x: torch.Tensor,
+    start_idxs: torch.Tensor,
+    segment_size: int,
+) -> torch.Tensor:
+    """Get segments.
+
+    Args:
+        x (Tensor): Input tensor (B, C, T).
+        start_idxs (Tensor): Start index tensor (B,).
+        segment_size (int): Segment size.
+
+    Returns:
+        Tensor: Segmented tensor (B, C, segment_size). If fewer than
+            ``segment_size`` frames are available after a start index
+            (e.g. T < segment_size), the tail of that segment stays zero.
+
+    """
+    b, c, _ = x.size()
+    segments = x.new_zeros(b, c, segment_size)
+    for i, start_idx in enumerate(start_idxs):
+        start = int(start_idx)
+        chunk = x[i, :, start : start + segment_size]
+        # Copy only the frames that exist; assigning a shorter slice to a
+        # full-width row would raise a shape-mismatch error.
+        segments[i, :, : chunk.size(-1)] = chunk
+    return segments
+
+
+# from https://github.com/jaywalnut310/vits/blob/main/commons.py
+def intersperse(sequence, item=0):
+    """Return a list with `item` placed before, between and after the
+    elements of `sequence`, e.g. [a, b] -> [item, a, item, b, item].
+    """
+    interleaved = [item] * (2 * len(sequence) + 1)
+    # Elements of `sequence` occupy the odd positions.
+    for position, element in enumerate(sequence):
+        interleaved[2 * position + 1] = element
+    return interleaved
+
+
+# from https://github.com/jaywalnut310/vits/blob/main/utils.py
+# Set to True once the non-interactive "Agg" backend has been selected.
+MATPLOTLIB_FLAG = False
+
+
+def plot_feature(spectrogram):
+    """Render a 2-D feature matrix as an RGB image array.
+
+    Args:
+        spectrogram: 2-D array-like passed to ``imshow``; presumably
+            (channels, frames) given the axis labels -- confirm with callers.
+
+    Returns:
+        numpy.ndarray: uint8 image of shape (height, width, 3) holding the
+            rendered figure.
+    """
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+
+        # Select the off-screen backend once and quieten matplotlib logging.
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger("matplotlib")
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    try:
+        im = ax.imshow(
+            spectrogram, aspect="auto", origin="lower", interpolation="none"
+        )
+        plt.colorbar(im, ax=ax)
+        plt.xlabel("Frames")
+        plt.ylabel("Channels")
+        plt.tight_layout()
+
+        fig.canvas.draw()
+        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
+        # NOTE(review): newer Matplotlib releases deprecate `tostring_rgb`
+        # in favour of `buffer_rgba` -- confirm against the pinned version.
+        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+        # get_width_height() is (width, height); reverse it for (H, W, 3).
+        # copy() keeps the result writable, as np.fromstring's output was.
+        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
+    finally:
+        # Close this specific figure, even if rendering failed.
+        plt.close(fig)
+    return data
+
+
+class MetricsTracker(collections.defaultdict):
+    """Dictionary of accumulated metric totals.
+
+    Missing keys default to 0. The special key "samples" holds the number of
+    samples the totals were accumulated over and is used to normalise the
+    other entries.
+    """
+
+    def __init__(self):
+        # Passing the type 'int' to the base-class constructor
+        # makes undefined items default to int() which is zero.
+        # This class will play a role as metrics tracker.
+        # It can record many metrics, including but not limited to loss.
+        super(MetricsTracker, self).__init__(int)
+
+    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+        """Return a new tracker holding the key-wise sum of both operands."""
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v
+        for k, v in other.items():
+            ans[k] = ans[k] + v
+        return ans
+
+    def __mul__(self, alpha: float) -> "MetricsTracker":
+        """Return a new tracker with every value scaled by `alpha`."""
+        ans = MetricsTracker()
+        for k, v in self.items():
+            ans[k] = v * alpha
+        return ans
+
+    def __str__(self) -> str:
+        """Format the normalised metrics followed by the sample count."""
+        ans = ""
+        for k, v in self.norm_items():
+            norm_value = "%.4g" % v
+            ans += str(k) + "=" + str(norm_value) + ", "
+        # Read with .get(): indexing a defaultdict would insert
+        # "samples": 0 as a side effect, and the next norm_items() call
+        # would then divide by zero.
+        samples = "%.2f" % self.get("samples", 0)
+        ans += "over " + str(samples) + " samples."
+        return ans
+
+    def norm_items(self) -> List[Tuple[str, float]]:
+        """
+        Returns a list of pairs, like:
+          [('loss_1', 0.1), ('loss_2', 0.07)]
+        Every value except "samples" is divided by the sample count
+        (1 when no "samples" entry exists).
+        """
+        samples = self["samples"] if "samples" in self else 1
+        ans = []
+        for k, v in self.items():
+            if k == "samples":
+                continue
+            norm_value = float(v) / samples
+            ans.append((k, norm_value))
+        return ans
+
+    def reduce(self, device):
+        """
+        Reduce using torch.distributed, which I believe ensures that
+        all processes get the total.
+
+        Args:
+          device: device on which the temporary tensor is created.
+        """
+        # Sort the keys so the tensor layout is deterministic.
+        keys = sorted(self.keys())
+        s = torch.tensor([float(self[k]) for k in keys], device=device)
+        dist.all_reduce(s, op=dist.ReduceOp.SUM)
+        for k, v in zip(keys, s.cpu().tolist()):
+            self[k] = v
+
+    def write_summary(
+        self,
+        tb_writer: SummaryWriter,
+        prefix: str,
+        batch_idx: int,
+    ) -> None:
+        """Add logging information to a TensorBoard writer.
+
+        Args:
+            tb_writer: a TensorBoard writer
+            prefix: a prefix for the name of the loss, e.g. "train/valid_",
+              or "train/current_"
+            batch_idx: The current batch index, used as the x-axis of the plot.
+        """
+        for k, v in self.norm_items():
+            tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
+# checkpoint saving and loading
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def save_checkpoint(
+    filename: Path,
+    model: Union[nn.Module, DDP],
+    params: Optional[Dict[str, Any]] = None,
+    optimizer_g: Optional[Optimizer] = None,
+    optimizer_d: Optional[Optimizer] = None,
+    scheduler_g: Optional[LRSchedulerType] = None,
+    scheduler_d: Optional[LRSchedulerType] = None,
+    scaler: Optional[GradScaler] = None,
+    sampler: Optional[CutSampler] = None,
+    rank: int = 0,
+) -> None:
+    """Write the training state to `filename` (rank 0 only).
+
+    Args:
+      filename:
+        The checkpoint filename.
+      model:
+        The model to be saved. Only its `state_dict()` is stored; a DDP
+        wrapper is unwrapped first.
+      params:
+        User defined parameters, e.g., epoch, loss. Its keys must not
+        clash with the built-in checkpoint keys.
+      optimizer_g:
+        The generator optimizer. Its `state_dict` will be saved.
+      optimizer_d:
+        The discriminator optimizer. Its `state_dict` will be saved.
+      scheduler_g:
+        The generator learning rate scheduler. Its `state_dict` will be
+        saved.
+      scheduler_d:
+        The discriminator learning rate scheduler. Its `state_dict` will
+        be saved.
+      scaler:
+        The GradScaler. Its `state_dict` is stored under "grad_scaler".
+      sampler:
+        The data sampler. Its `state_dict` will be saved.
+      rank:
+        Used in DDP. We save checkpoint only for the node whose rank is 0.
+    Returns:
+      Return None.
+    """
+    if rank != 0:
+        return
+
+    logging.info(f"Saving checkpoint to {filename}")
+
+    if isinstance(model, DDP):
+        model = model.module
+
+    # Components that were not supplied are recorded as None so that the
+    # checkpoint always carries the same set of keys.
+    optional_components = {
+        "optimizer_g": optimizer_g,
+        "optimizer_d": optimizer_d,
+        "scheduler_g": scheduler_g,
+        "scheduler_d": scheduler_d,
+        "grad_scaler": scaler,
+        "sampler": sampler,
+    }
+    checkpoint = {"model": model.state_dict()}
+    for key, component in optional_components.items():
+        checkpoint[key] = None if component is None else component.state_dict()
+
+    for k, v in (params or {}).items():
+        assert k not in checkpoint
+        checkpoint[k] = v
+
+    torch.save(checkpoint, filename)
diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py
new file mode 100644
index 000000000..d5e20a578
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/vits.py
@@ -0,0 +1,610 @@
+# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""VITS module for GAN-TTS task."""
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.cuda.amp import autocast
+
+from hifigan import (
+ HiFiGANMultiPeriodDiscriminator,
+ HiFiGANMultiScaleDiscriminator,
+ HiFiGANMultiScaleMultiPeriodDiscriminator,
+ HiFiGANPeriodDiscriminator,
+ HiFiGANScaleDiscriminator,
+)
+from loss import (
+ DiscriminatorAdversarialLoss,
+ FeatureMatchLoss,
+ GeneratorAdversarialLoss,
+ KLDivergenceLoss,
+ MelSpectrogramLoss,
+)
+from utils import get_segments
+from generator import VITSGenerator
+
+
+# Maps the `generator_type` string accepted by VITS.__init__ to the
+# generator class that gets instantiated.
+AVAILABLE_GENERATERS = {
+ "vits_generator": VITSGenerator,
+}
+# Maps the `discriminator_type` string accepted by VITS.__init__ to the
+# discriminator class that gets instantiated.
+AVAILABLE_DISCRIMINATORS = {
+ "hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
+ "hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
+ "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
+ "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
+ "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
+}
+
+
+class VITS(nn.Module):
+    """Implement VITS, `Conditional Variational Autoencoder with Adversarial
+    Learning for End-to-End Text-to-Speech`.
+
+    The module bundles a generator, a discriminator and the losses used to
+    train them. `forward` computes either the generator loss or the
+    discriminator loss depending on `forward_generator`.
+    """
+
+    def __init__(
+        self,
+        # generator related
+        vocab_size: int,
+        feature_dim: int = 513,
+        sampling_rate: int = 22050,
+        generator_type: str = "vits_generator",
+        generator_params: Dict[str, Any] = {
+            "hidden_channels": 192,
+            "spks": None,
+            "langs": None,
+            "spk_embed_dim": None,
+            "global_channels": -1,
+            "segment_size": 32,
+            "text_encoder_attention_heads": 2,
+            "text_encoder_ffn_expand": 4,
+            "text_encoder_cnn_module_kernel": 5,
+            "text_encoder_blocks": 6,
+            "text_encoder_dropout_rate": 0.1,
+            "decoder_kernel_size": 7,
+            "decoder_channels": 512,
+            "decoder_upsample_scales": [8, 8, 2, 2],
+            "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
+            "decoder_resblock_kernel_sizes": [3, 7, 11],
+            "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            "use_weight_norm_in_decoder": True,
+            "posterior_encoder_kernel_size": 5,
+            "posterior_encoder_layers": 16,
+            "posterior_encoder_stacks": 1,
+            "posterior_encoder_base_dilation": 1,
+            "posterior_encoder_dropout_rate": 0.0,
+            "use_weight_norm_in_posterior_encoder": True,
+            "flow_flows": 4,
+            "flow_kernel_size": 5,
+            "flow_base_dilation": 1,
+            "flow_layers": 4,
+            "flow_dropout_rate": 0.0,
+            "use_weight_norm_in_flow": True,
+            "use_only_mean_in_flow": True,
+            "stochastic_duration_predictor_kernel_size": 3,
+            "stochastic_duration_predictor_dropout_rate": 0.5,
+            "stochastic_duration_predictor_flows": 4,
+            "stochastic_duration_predictor_dds_conv_layers": 3,
+        },
+        # discriminator related
+        discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
+        discriminator_params: Dict[str, Any] = {
+            "scales": 1,
+            "scale_downsample_pooling": "AvgPool1d",
+            "scale_downsample_pooling_params": {
+                "kernel_size": 4,
+                "stride": 2,
+                "padding": 2,
+            },
+            "scale_discriminator_params": {
+                "in_channels": 1,
+                "out_channels": 1,
+                "kernel_sizes": [15, 41, 5, 3],
+                "channels": 128,
+                "max_downsample_channels": 1024,
+                "max_groups": 16,
+                "bias": True,
+                "downsample_scales": [2, 2, 4, 4, 1],
+                "nonlinear_activation": "LeakyReLU",
+                "nonlinear_activation_params": {"negative_slope": 0.1},
+                "use_weight_norm": True,
+                "use_spectral_norm": False,
+            },
+            "follow_official_norm": False,
+            "periods": [2, 3, 5, 7, 11],
+            "period_discriminator_params": {
+                "in_channels": 1,
+                "out_channels": 1,
+                "kernel_sizes": [5, 3],
+                "channels": 32,
+                "downsample_scales": [3, 3, 3, 3, 1],
+                "max_downsample_channels": 1024,
+                "bias": True,
+                "nonlinear_activation": "LeakyReLU",
+                "nonlinear_activation_params": {"negative_slope": 0.1},
+                "use_weight_norm": True,
+                "use_spectral_norm": False,
+            },
+        },
+        # loss related
+        generator_adv_loss_params: Dict[str, Any] = {
+            "average_by_discriminators": False,
+            "loss_type": "mse",
+        },
+        discriminator_adv_loss_params: Dict[str, Any] = {
+            "average_by_discriminators": False,
+            "loss_type": "mse",
+        },
+        feat_match_loss_params: Dict[str, Any] = {
+            "average_by_discriminators": False,
+            "average_by_layers": False,
+            "include_final_outputs": True,
+        },
+        mel_loss_params: Dict[str, Any] = {
+            "frame_shift": 256,
+            "frame_length": 1024,
+            "n_mels": 80,
+        },
+        lambda_adv: float = 1.0,
+        lambda_mel: float = 45.0,
+        lambda_feat_match: float = 2.0,
+        lambda_dur: float = 1.0,
+        lambda_kl: float = 1.0,
+        cache_generator_outputs: bool = True,
+    ):
+        """Initialize VITS module.
+
+        Args:
+            vocab_size (int): Input vocabulary size.
+            feature_dim (int): Acoustic feature dimension. It is passed to the
+                generator as `aux_channels`.
+            sampling_rate (int): Sampling rate. It is passed to the mel
+                spectrogram loss and stored for saving waveforms at inference.
+            generator_type (str): Generator type.
+            generator_params (Dict[str, Any]): Parameter dict for generator.
+            discriminator_type (str): Discriminator type.
+            discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
+            generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
+                adversarial loss.
+            discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
+                discriminator adversarial loss.
+            feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
+            mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
+            lambda_adv (float): Loss scaling coefficient for adversarial loss.
+            lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
+            lambda_feat_match (float): Loss scaling coefficient for feat match loss.
+            lambda_dur (float): Loss scaling coefficient for duration loss.
+            lambda_kl (float): Loss scaling coefficient for KL divergence loss.
+            cache_generator_outputs (bool): Whether to cache generator outputs.
+
+        """
+        super().__init__()
+
+        # define modules
+        generator_class = AVAILABLE_GENERATERS[generator_type]
+        # Work on shallow copies: the dicts may be the shared default
+        # arguments or objects owned by the caller, and must not be mutated.
+        generator_params = dict(generator_params)
+        if generator_type == "vits_generator":
+            # NOTE(kan-bayashi): Update parameters for the compatibility.
+            #   The vocabulary size and the acoustic feature dimension are
+            #   injected here rather than supplied in generator_params.
+            generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
+        self.generator = generator_class(
+            **generator_params,
+        )
+        discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
+        self.discriminator = discriminator_class(
+            **discriminator_params,
+        )
+        self.generator_adv_loss = GeneratorAdversarialLoss(
+            **generator_adv_loss_params,
+        )
+        self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
+            **discriminator_adv_loss_params,
+        )
+        self.feat_match_loss = FeatureMatchLoss(
+            **feat_match_loss_params,
+        )
+        mel_loss_params = dict(mel_loss_params, sampling_rate=sampling_rate)
+        self.mel_loss = MelSpectrogramLoss(
+            **mel_loss_params,
+        )
+        self.kl_loss = KLDivergenceLoss()
+
+        # coefficients
+        self.lambda_adv = lambda_adv
+        self.lambda_mel = lambda_mel
+        self.lambda_kl = lambda_kl
+        self.lambda_feat_match = lambda_feat_match
+        self.lambda_dur = lambda_dur
+
+        # cache
+        self.cache_generator_outputs = cache_generator_outputs
+        self._cache = None
+
+        # store sampling rate for saving wav file
+        # (not used for the training)
+        self.sampling_rate = sampling_rate
+
+        # store parameters for test compatibility
+        self.spks = self.generator.spks
+        self.langs = self.generator.langs
+        self.spk_embed_dim = self.generator.spk_embed_dim
+
+    def forward(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        feats: torch.Tensor,
+        feats_lengths: torch.Tensor,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        return_sample: bool = False,
+        sids: Optional[torch.Tensor] = None,
+        spembs: Optional[torch.Tensor] = None,
+        lids: Optional[torch.Tensor] = None,
+        forward_generator: bool = True,
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        """Compute the generator loss or the discriminator loss.
+
+        Args:
+            text (Tensor): Text index tensor (B, T_text).
+            text_lengths (Tensor): Text length tensor (B,).
+            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+            feats_lengths (Tensor): Feature length tensor (B,).
+            speech (Tensor): Speech waveform tensor (B, T_wav).
+            speech_lengths (Tensor): Speech length tensor (B,).
+            return_sample (bool): Whether to add a generated sample to the
+                stats. Only used when `forward_generator` is True.
+            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+            forward_generator (bool): Whether to forward generator.
+
+        Returns:
+            - loss (Tensor): Loss scalar tensor.
+            - stats (Dict[str, float]): Statistics to be monitored.
+        """
+        if forward_generator:
+            return self._forward_generator(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                speech=speech,
+                speech_lengths=speech_lengths,
+                return_sample=return_sample,
+                sids=sids,
+                spembs=spembs,
+                lids=lids,
+            )
+        else:
+            return self._forward_discrminator(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                speech=speech,
+                speech_lengths=speech_lengths,
+                sids=sids,
+                spembs=spembs,
+                lids=lids,
+            )
+
+    def _generator_outputs(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        feats: torch.Tensor,
+        feats_lengths: torch.Tensor,
+        sids: Optional[torch.Tensor] = None,
+        spembs: Optional[torch.Tensor] = None,
+        lids: Optional[torch.Tensor] = None,
+    ) -> Tuple[Any, bool]:
+        """Return the cached generator outputs, or compute fresh ones.
+
+        `feats` must already be transposed to (B, aux_channels, T_feats).
+        This helper never writes `self._cache`; the callers decide whether
+        to store the fresh outputs.
+
+        Returns:
+            * outs: Generator outputs.
+            * reuse_cache (bool): True when `outs` came from `self._cache`.
+        """
+        if self.cache_generator_outputs and self._cache is not None:
+            return self._cache, True
+        outs = self.generator(
+            text=text,
+            text_lengths=text_lengths,
+            feats=feats,
+            feats_lengths=feats_lengths,
+            sids=sids,
+            spembs=spembs,
+            lids=lids,
+        )
+        return outs, False
+
+    def _forward_generator(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        feats: torch.Tensor,
+        feats_lengths: torch.Tensor,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        return_sample: bool = False,
+        sids: Optional[torch.Tensor] = None,
+        spembs: Optional[torch.Tensor] = None,
+        lids: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        """Perform generator forward.
+
+        Args:
+            text (Tensor): Text index tensor (B, T_text).
+            text_lengths (Tensor): Text length tensor (B,).
+            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+            feats_lengths (Tensor): Feature length tensor (B,).
+            speech (Tensor): Speech waveform tensor (B, T_wav).
+            speech_lengths (Tensor): Speech length tensor (B,).
+            return_sample (bool): Whether to store the first item of the batch
+                (generated and reference waveform segments and mels) in
+                `stats["returned_sample"]`.
+            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+        Returns:
+            * loss (Tensor): Loss scalar tensor.
+            * stats (Dict[str, float]): Statistics to be monitored.
+        """
+        # setup
+        feats = feats.transpose(1, 2)
+        speech = speech.unsqueeze(1)
+
+        # calculate generator outputs
+        outs, reuse_cache = self._generator_outputs(
+            text=text,
+            text_lengths=text_lengths,
+            feats=feats,
+            feats_lengths=feats_lengths,
+            sids=sids,
+            spembs=spembs,
+            lids=lids,
+        )
+
+        # store cache
+        if self.training and self.cache_generator_outputs and not reuse_cache:
+            self._cache = outs
+
+        # parse outputs
+        speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
+        _, z_p, m_p, logs_p, _, logs_q = outs_
+        # Cut the reference waveform at the positions the generator used;
+        # frame indices are converted to sample indices via upsample_factor.
+        speech_ = get_segments(
+            x=speech,
+            start_idxs=start_idxs * self.generator.upsample_factor,
+            segment_size=self.generator.segment_size * self.generator.upsample_factor,
+        )
+
+        # calculate discriminator outputs
+        p_hat = self.discriminator(speech_hat_)
+        with torch.no_grad():
+            # do not store discriminator gradient in generator turn
+            p = self.discriminator(speech_)
+
+        # calculate losses
+        with autocast(enabled=False):
+            if not return_sample:
+                mel_loss = self.mel_loss(speech_hat_, speech_)
+            else:
+                mel_loss, (mel_hat_, mel_) = self.mel_loss(
+                    speech_hat_, speech_, return_mel=True
+                )
+            kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
+            dur_loss = torch.sum(dur_nll.float())
+            adv_loss = self.generator_adv_loss(p_hat)
+            feat_match_loss = self.feat_match_loss(p_hat, p)
+
+            mel_loss = mel_loss * self.lambda_mel
+            kl_loss = kl_loss * self.lambda_kl
+            dur_loss = dur_loss * self.lambda_dur
+            adv_loss = adv_loss * self.lambda_adv
+            feat_match_loss = feat_match_loss * self.lambda_feat_match
+            loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
+
+        stats = dict(
+            generator_loss=loss.item(),
+            generator_mel_loss=mel_loss.item(),
+            generator_kl_loss=kl_loss.item(),
+            generator_dur_loss=dur_loss.item(),
+            generator_adv_loss=adv_loss.item(),
+            generator_feat_match_loss=feat_match_loss.item(),
+        )
+
+        if return_sample:
+            stats["returned_sample"] = (
+                speech_hat_[0].data.cpu().numpy(),
+                speech_[0].data.cpu().numpy(),
+                mel_hat_[0].data.cpu().numpy(),
+                mel_[0].data.cpu().numpy(),
+            )
+
+        # reset cache
+        if reuse_cache or not self.training:
+            self._cache = None
+
+        return loss, stats
+
+    def _forward_discrminator(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        feats: torch.Tensor,
+        feats_lengths: torch.Tensor,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        sids: Optional[torch.Tensor] = None,
+        spembs: Optional[torch.Tensor] = None,
+        lids: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        """Perform discriminator forward.
+
+        Args:
+            text (Tensor): Text index tensor (B, T_text).
+            text_lengths (Tensor): Text length tensor (B,).
+            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+            feats_lengths (Tensor): Feature length tensor (B,).
+            speech (Tensor): Speech waveform tensor (B, T_wav).
+            speech_lengths (Tensor): Speech length tensor (B,).
+            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+        Returns:
+            * loss (Tensor): Loss scalar tensor.
+            * stats (Dict[str, float]): Statistics to be monitored.
+        """
+        # setup
+        feats = feats.transpose(1, 2)
+        speech = speech.unsqueeze(1)
+
+        # calculate generator outputs
+        outs, reuse_cache = self._generator_outputs(
+            text=text,
+            text_lengths=text_lengths,
+            feats=feats,
+            feats_lengths=feats_lengths,
+            sids=sids,
+            spembs=spembs,
+            lids=lids,
+        )
+
+        # store cache
+        if self.cache_generator_outputs and not reuse_cache:
+            self._cache = outs
+
+        # parse outputs
+        speech_hat_, _, _, start_idxs, *_ = outs
+        speech_ = get_segments(
+            x=speech,
+            start_idxs=start_idxs * self.generator.upsample_factor,
+            segment_size=self.generator.segment_size * self.generator.upsample_factor,
+        )
+
+        # calculate discriminator outputs
+        # detach() keeps the discriminator loss from back-propagating into
+        # the generator.
+        p_hat = self.discriminator(speech_hat_.detach())
+        p = self.discriminator(speech_)
+
+        # calculate losses
+        with autocast(enabled=False):
+            real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
+            loss = real_loss + fake_loss
+
+        stats = dict(
+            discriminator_loss=loss.item(),
+            discriminator_real_loss=real_loss.item(),
+            discriminator_fake_loss=fake_loss.item(),
+        )
+
+        # reset cache
+        if reuse_cache or not self.training:
+            self._cache = None
+
+        return loss, stats
+
+    def inference(
+        self,
+        text: torch.Tensor,
+        feats: Optional[torch.Tensor] = None,
+        sids: Optional[torch.Tensor] = None,
+        spembs: Optional[torch.Tensor] = None,
+        lids: Optional[torch.Tensor] = None,
+        durations: Optional[torch.Tensor] = None,
+        noise_scale: float = 0.667,
+        noise_scale_dur: float = 0.8,
+        alpha: float = 1.0,
+        max_len: Optional[int] = None,
+        use_teacher_forcing: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Run inference for single sample.
+
+        Args:
+            text (Tensor): Input text index tensor (T_text,).
+            feats (Tensor): Feature tensor (T_feats, aux_channels). Required
+                when `use_teacher_forcing` is True.
+            sids (Tensor): Speaker index tensor (1,).
+            spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
+            lids (Tensor): Language index tensor (1,).
+            durations (Tensor): Ground-truth duration tensor (T_text,). Only
+                forwarded when `use_teacher_forcing` is False.
+            noise_scale (float): Noise scale value for flow.
+            noise_scale_dur (float): Noise scale value for duration predictor.
+            alpha (float): Alpha parameter to control the speed of generated speech.
+            max_len (Optional[int]): Maximum length.
+            use_teacher_forcing (bool): Whether to use teacher forcing.
+
+        Returns:
+            * wav (Tensor): Generated waveform tensor (T_wav,).
+            * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
+            * duration (Tensor): Predicted duration tensor (T_text,).
+        """
+        # setup: add the batch dimension expected by the generator
+        text = text[None]
+        text_lengths = torch.tensor(
+            [text.size(1)],
+            dtype=torch.long,
+            device=text.device,
+        )
+        if sids is not None:
+            sids = sids.view(1)
+        if lids is not None:
+            lids = lids.view(1)
+        if durations is not None:
+            durations = durations.view(1, 1, -1)
+
+        # inference
+        if use_teacher_forcing:
+            assert feats is not None
+            feats = feats[None].transpose(1, 2)
+            feats_lengths = torch.tensor(
+                [feats.size(2)],
+                dtype=torch.long,
+                device=feats.device,
+            )
+            wav, att_w, dur = self.generator.inference(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                sids=sids,
+                spembs=spembs,
+                lids=lids,
+                max_len=max_len,
+                use_teacher_forcing=use_teacher_forcing,
+            )
+        else:
+            wav, att_w, dur = self.generator.inference(
+                text=text,
+                text_lengths=text_lengths,
+                sids=sids,
+                spembs=spembs,
+                lids=lids,
+                dur=durations,
+                noise_scale=noise_scale,
+                noise_scale_dur=noise_scale_dur,
+                alpha=alpha,
+                max_len=max_len,
+            )
+        return wav.view(-1), att_w[0], dur[0]
+
+    def inference_batch(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        sids: Optional[torch.Tensor] = None,
+        durations: Optional[torch.Tensor] = None,
+        noise_scale: float = 0.667,
+        noise_scale_dur: float = 0.8,
+        alpha: float = 1.0,
+        max_len: Optional[int] = None,
+        use_teacher_forcing: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Run inference for one batch.
+
+        Args:
+            text (Tensor): Input text index tensor (B, T_text).
+            text_lengths (Tensor): Text length tensor (B,).
+            sids (Tensor): Speaker index tensor (B,).
+            durations (Optional[Tensor]): Ground-truth durations, forwarded
+                unchanged to the generator as `dur`; presumably
+                (B, 1, T_text) -- confirm against the generator.
+            noise_scale (float): Noise scale value for flow.
+            noise_scale_dur (float): Noise scale value for duration predictor.
+            alpha (float): Alpha parameter to control the speed of generated speech.
+            max_len (Optional[int]): Maximum length.
+            use_teacher_forcing (bool): Accepted for signature compatibility
+                with `inference`; currently ignored because no features are
+                passed to this method.
+
+        Returns:
+            * wav (Tensor): Generated waveform tensor (B, T_wav).
+            * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
+            * duration (Tensor): Predicted duration tensor (B, T_text).
+        """
+        # inference
+        wav, att_w, dur = self.generator.inference(
+            text=text,
+            text_lengths=text_lengths,
+            sids=sids,
+            dur=durations,
+            noise_scale=noise_scale,
+            noise_scale_dur=noise_scale_dur,
+            alpha=alpha,
+            max_len=max_len,
+        )
+        return wav, att_w, dur
diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py
new file mode 100644
index 000000000..fbe1be52b
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/wavenet.py
@@ -0,0 +1,349 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""WaveNet modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+import math
+import logging
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+class WaveNet(torch.nn.Module):
+ """WaveNet with global conditioning."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ kernel_size: int = 3,
+ layers: int = 30,
+ stacks: int = 3,
+ base_dilation: int = 2,
+ residual_channels: int = 64,
+ aux_channels: int = -1,
+ gate_channels: int = 128,
+ skip_channels: int = 64,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ bias: bool = True,
+ use_weight_norm: bool = True,
+ use_first_conv: bool = False,
+ use_last_conv: bool = False,
+ scale_residual: bool = False,
+ scale_skip_connect: bool = False,
+ ):
+ """Initialize WaveNet module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ kernel_size (int): Kernel size of dilated convolution.
+ layers (int): Number of residual block layers.
+ stacks (int): Number of stacks i.e., dilation cycles.
+ base_dilation (int): Base dilation factor.
+ residual_channels (int): Number of channels in residual conv.
+ gate_channels (int): Number of channels in gated conv.
+ skip_channels (int): Number of channels in skip conv.
+ aux_channels (int): Number of channels for local conditioning feature.
+ global_channels (int): Number of channels for global conditioning feature.
+ dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
+ bias (bool): Whether to use bias parameter in conv layer.
+ use_weight_norm (bool): Whether to use weight norm. If set to true, it will
+ be applied to all of the conv layers.
+ use_first_conv (bool): Whether to use the first conv layers.
+ use_last_conv (bool): Whether to use the last conv layers.
+ scale_residual (bool): Whether to scale the residual outputs.
+ scale_skip_connect (bool): Whether to scale the skip connection outputs.
+
+ """
+ super().__init__()
+ self.layers = layers
+ self.stacks = stacks
+ self.kernel_size = kernel_size
+ self.base_dilation = base_dilation
+ self.use_first_conv = use_first_conv
+ self.use_last_conv = use_last_conv
+ self.scale_skip_connect = scale_skip_connect
+
+ # check the number of layers and stacks
+ assert layers % stacks == 0
+ layers_per_stack = layers // stacks
+
+ # define first convolution
+ if self.use_first_conv:
+ self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
+
+ # define residual blocks
+ self.conv_layers = torch.nn.ModuleList()
+ for layer in range(layers):
+ dilation = base_dilation ** (layer % layers_per_stack)
+ conv = ResidualBlock(
+ kernel_size=kernel_size,
+ residual_channels=residual_channels,
+ gate_channels=gate_channels,
+ skip_channels=skip_channels,
+ aux_channels=aux_channels,
+ global_channels=global_channels,
+ dilation=dilation,
+ dropout_rate=dropout_rate,
+ bias=bias,
+ scale_residual=scale_residual,
+ )
+ self.conv_layers += [conv]
+
+ # define output layers
+ if self.use_last_conv:
+ self.last_conv = torch.nn.Sequential(
+ torch.nn.ReLU(inplace=True),
+ Conv1d1x1(skip_channels, skip_channels, bias=True),
+ torch.nn.ReLU(inplace=True),
+ Conv1d1x1(skip_channels, out_channels, bias=True),
+ )
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: Optional[torch.Tensor] = None,
+ c: Optional[torch.Tensor] = None,
+ g: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
+ (B, residual_channels, T).
+ x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
+ c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
+ g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (B, out_channels, T) if use_last_conv else
+ (B, residual_channels, T).
+
+ """
+ # encode to hidden representation
+ if self.use_first_conv:
+ x = self.first_conv(x)
+
+ # residual block
+ skips = 0.0
+ for f in self.conv_layers:
+ x, h = f(x, x_mask=x_mask, c=c, g=g)
+ skips = skips + h
+ x = skips
+ if self.scale_skip_connect:
+ x = x * math.sqrt(1.0 / len(self.conv_layers))
+
+ # apply final layers
+ if self.use_last_conv:
+ x = self.last_conv(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m: torch.nn.Module):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ @staticmethod
+ def _get_receptive_field_size(
+ layers: int,
+ stacks: int,
+ kernel_size: int,
+ base_dilation: int,
+ ) -> int:
+ assert layers % stacks == 0
+ layers_per_cycle = layers // stacks
+ dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
+ return (kernel_size - 1) * sum(dilations) + 1
+
+ @property
+ def receptive_field_size(self) -> int:
+ """Return receptive field size."""
+ return self._get_receptive_field_size(
+ self.layers, self.stacks, self.kernel_size, self.base_dilation
+ )
+
+
+class Conv1d(torch.nn.Conv1d):
+ """Conv1d module with customized initialization."""
+
+ def __init__(self, *args, **kwargs):
+ """Initialize Conv1d module."""
+ super().__init__(*args, **kwargs)
+
+ def reset_parameters(self):
+ """Reset parameters."""
+ torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
+ if self.bias is not None:
+ torch.nn.init.constant_(self.bias, 0.0)
+
+
+class Conv1d1x1(Conv1d):
+ """1x1 Conv1d with customized initialization."""
+
+ def __init__(self, in_channels: int, out_channels: int, bias: bool):
+ """Initialize 1x1 Conv1d module."""
+ super().__init__(
+ in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
+ )
+
+
+class ResidualBlock(torch.nn.Module):
+ """Residual block module in WaveNet."""
+
+ def __init__(
+ self,
+ kernel_size: int = 3,
+ residual_channels: int = 64,
+ gate_channels: int = 128,
+ skip_channels: int = 64,
+ aux_channels: int = 80,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ dilation: int = 1,
+ bias: bool = True,
+ scale_residual: bool = False,
+ ):
+ """Initialize ResidualBlock module.
+
+ Args:
+ kernel_size (int): Kernel size of dilation convolution layer.
+ residual_channels (int): Number of channels for residual connection.
+ skip_channels (int): Number of channels for skip connection.
+ aux_channels (int): Number of local conditioning channels.
+ dropout (float): Dropout probability.
+ dilation (int): Dilation factor.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ scale_residual (bool): Whether to scale the residual outputs.
+
+ """
+ super().__init__()
+ self.dropout_rate = dropout_rate
+ self.residual_channels = residual_channels
+ self.skip_channels = skip_channels
+ self.scale_residual = scale_residual
+
+ # check
+ assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
+ assert gate_channels % 2 == 0
+
+ # dilation conv
+ padding = (kernel_size - 1) // 2 * dilation
+ self.conv = Conv1d(
+ residual_channels,
+ gate_channels,
+ kernel_size,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ # local conditioning
+ if aux_channels > 0:
+ self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
+ else:
+ self.conv1x1_aux = None
+
+ # global conditioning
+ if global_channels > 0:
+ self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
+ else:
+ self.conv1x1_glo = None
+
+ # conv output is split into two groups
+ gate_out_channels = gate_channels // 2
+
+ # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
+ # (integrate res 1x1 + skip 1x1 convs)
+ self.conv1x1_out = Conv1d1x1(
+ gate_out_channels, residual_channels + skip_channels, bias=bias
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: Optional[torch.Tensor] = None,
+ c: Optional[torch.Tensor] = None,
+ g: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, residual_channels, T).
+ x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T).
+ c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor for residual connection (B, residual_channels, T).
+ Tensor: Output tensor for skip connection (B, skip_channels, T).
+
+ """
+ residual = x
+ x = F.dropout(x, p=self.dropout_rate, training=self.training)
+ x = self.conv(x)
+
+ # split into two part for gated activation
+ splitdim = 1
+ xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
+
+ # local conditioning
+ if c is not None:
+ c = self.conv1x1_aux(c)
+ ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
+ xa, xb = xa + ca, xb + cb
+
+ # global conditioning
+ if g is not None:
+ g = self.conv1x1_glo(g)
+ ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
+ xa, xb = xa + ga, xb + gb
+
+ x = torch.tanh(xa) * torch.sigmoid(xb)
+
+ # residual + skip 1x1 conv
+ x = self.conv1x1_out(x)
+ if x_mask is not None:
+ x = x * x_mask
+
+ # split integrated conv results
+ x, s = x.split([self.residual_channels, self.skip_channels], dim=1)
+
+ # for residual connection
+ x = x + residual
+ if self.scale_residual:
+ x = x * math.sqrt(0.5)
+
+ return x, s
diff --git a/egs/multi_zh_en/ASR/README.md b/egs/multi_zh_en/ASR/README.md
new file mode 100644
index 000000000..29341571d
--- /dev/null
+++ b/egs/multi_zh_en/ASR/README.md
@@ -0,0 +1,19 @@
+# Introduction
+
+This recipe includes scripts for training a Zipformer model using both English and Chinese datasets.
+
+# Included Training Sets
+
+1. LibriSpeech (English)
+2. AiShell-2 (Chinese)
+3. TAL-CSASR (Code-Switching, Chinese and English)
+
+|Dataset| Number of hours| URL|
+|---|---:|---|
+|**TOTAL**|2,547|---|
+|LibriSpeech|960|https://www.openslr.org/12/|
+|AiShell-2|1,000|http://www.aishelltech.com/aishell_2|
+|TAL-CSASR|587|https://ai.100tal.com/openData/voice|
+
+
+
diff --git a/egs/multi_zh_en/ASR/RESULTS.md b/egs/multi_zh_en/ASR/RESULTS.md
new file mode 100644
index 000000000..3562d6ac3
--- /dev/null
+++ b/egs/multi_zh_en/ASR/RESULTS.md
@@ -0,0 +1,44 @@
+## Results
+
+### Zh-En datasets bpe-based training results (Non-streaming) on Zipformer model
+
+This is the [pull request #1265](https://github.com/k2-fsa/icefall/pull/1265) in icefall.
+
+#### Non-streaming (Byte-Level BPE vocab_size=2000)
+
+Best results (num of params : ~69M):
+
+The training command:
+
+```
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 35 \
+ --use-fp16 1 \
+ --max-duration 1000 \
+ --num-workers 8
+```
+
+The decoding command:
+
+```
+for method in greedy_search modified_beam_search fast_beam_search; do
+ ./zipformer/decode.py \
+ --epoch 34 \
+ --avg 19 \
+ --decoding-method $method
+done
+```
+
+Word Error Rates (WERs) listed below are produced by the checkpoint of the 20th epoch using greedy search and BPE model (# tokens is 2000).
+
+| Datasets | TAL-CSASR | TAL-CSASR | AiShell-2 | AiShell-2 | LibriSpeech | LibriSpeech |
+|----------------------|-----------|-----------|-----------|-----------|-------------|-------------|
+| Zipformer WER (%) | dev | test | dev | test | test-clean | test-other |
+| greedy_search | 6.65 | 6.69 | 6.57 | 7.03 | 2.43 | 5.70 |
+| modified_beam_search | 6.46 | 6.51 | 6.18 | 6.60 | 2.41 | 5.57 |
+| fast_beam_search | 6.57 | 6.68 | 6.40 | 6.74 | 2.40 | 5.56 |
+
+Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22, which is trained on LibriSpeech 960-hour training set (with speed perturbation), TAL-CSASR training set (with speed perturbation) and AiShell-2 (w/o speed perturbation).
+
+
diff --git a/egs/multi_zh_en/ASR/local/compile_lg.py b/egs/multi_zh_en/ASR/local/compile_lg.py
new file mode 120000
index 000000000..462d6d3fb
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_char.py b/egs/multi_zh_en/ASR/local/prepare_char.py
new file mode 120000
index 000000000..42743b544
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_char.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/prepare_char.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py
new file mode 100755
index 000000000..00514e6bb
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_for_bpe_model.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script tokenizes the training transcript by CJK characters
+# and saves the result to transcript_chars.txt, which is used
+# to train the BPE model later.
+
+import argparse
+from pathlib import Path
+
+from tqdm.auto import tqdm
+
+from icefall.utils import tokenize_by_CJK_char
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ help="""Output directory.
+ The generated transcript_chars.txt is saved to this directory.
+ """,
+ )
+
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="Training transcript.",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ lang_dir = Path(args.lang_dir)
+ text = Path(args.text)
+
+ assert lang_dir.exists() and text.exists(), f"{lang_dir} or {text} does not exist!"
+
+ transcript_path = lang_dir / "transcript_chars.txt"
+
+ with open(text, "r", encoding="utf-8") as fin:
+ with open(transcript_path, "w+", encoding="utf-8") as fout:
+ for line in tqdm(fin):
+ fout.write(tokenize_by_CJK_char(line) + "\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/multi_zh_en/ASR/local/prepare_lang.py b/egs/multi_zh_en/ASR/local/prepare_lang.py
new file mode 120000
index 000000000..747f2ab39
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_lang.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py
new file mode 120000
index 000000000..9a0b44642
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_lang_bbpe.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/prepare_lang_bbpe.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/prepare_words.py b/egs/multi_zh_en/ASR/local/prepare_words.py
new file mode 120000
index 000000000..ef2b4eaf3
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/prepare_words.py
@@ -0,0 +1 @@
+../../../aishell2/ASR/local/prepare_words.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/text2segments.py b/egs/multi_zh_en/ASR/local/text2segments.py
new file mode 120000
index 000000000..7d68a39c3
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/text2segments.py
@@ -0,0 +1 @@
+../../../wenetspeech/ASR/local/text2segments.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/text2token.py b/egs/multi_zh_en/ASR/local/text2token.py
new file mode 120000
index 000000000..ce5cfd537
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/text2token.py
@@ -0,0 +1 @@
+../../../wenetspeech/ASR/local/text2token.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/train_bbpe_model.py b/egs/multi_zh_en/ASR/local/train_bbpe_model.py
new file mode 120000
index 000000000..7fb4a9f9d
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/train_bbpe_model.py
@@ -0,0 +1 @@
+../../../aishell/ASR/local/train_bbpe_model.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/multi_zh_en/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/prepare.sh b/egs/multi_zh_en/ASR/prepare.sh
new file mode 100755
index 000000000..9f2be5a5c
--- /dev/null
+++ b/egs/multi_zh_en/ASR/prepare.sh
@@ -0,0 +1,149 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+vocab_sizes=(
+ 2000
+)
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+# Stages 1-3 reuse features computed by other recipes: each stage aborts when
+# the other recipe's ".done" marker is missing, otherwise it symlinks that
+# recipe's files into data/fbank.
+
+log "Dataset: musan"
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Soft link fbank of musan"
+  mkdir -p data/fbank
+  if [ ! -e ../../librispeech/ASR/data/fbank/.musan.done ]; then
+    log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 4 --stop-stage 4"
+    exit 1
+  fi
+  cd data/fbank
+  ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_feats) .
+  ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/musan_cuts.jsonl.gz) .
+  cd ../..
+fi
+
+log "Dataset: LibriSpeech"
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Soft link fbank of LibriSpeech"
+  mkdir -p data/fbank
+  if [ ! -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then
+    log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3"
+    exit 1
+  fi
+  cd data/fbank
+  ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts*) .
+  ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats*) .
+  cd ../..
+fi
+
+log "Dataset: AiShell-2"
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Soft link fbank of AiShell-2"
+  mkdir -p data/fbank
+  if [ ! -e ../../aishell2/ASR/data/fbank/.aishell2.done ]; then
+    log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
+    exit 1
+  fi
+  cd data/fbank
+  ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_cuts*) .
+  ln -svf $(realpath ../../../../aishell2/ASR/data/fbank/aishell2_feats*) .
+  cd ../..
+fi
+
+# Stage 4 builds data/lang_bbpe_<vocab_size> from the AiShell-2 char lang dir
+# and the LibriSpeech BPE lang dir, both reused from the other recipes.
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Prepare Byte BPE based lang"
+  mkdir -p data/fbank
+  # NOTE(review): the hint below names aishell2 stage 3, the same stage stage 3
+  # of this script points at for fbank; confirm it is the stage that creates
+  # data/lang_char.
+  if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then
+    log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
+    exit 1
+  fi
+
+  if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then
+    log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6"
+    exit 1
+  fi
+
+  # Link the two source lang dirs into ./data unless they are already there.
+  cd data/
+  if [ ! -d ./lang_char ]; then
+    ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
+  fi
+  if [ ! -d ./lang_bpe_500 ]; then
+    ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+  fi
+  cd ../
+
+  for vocab_size in "${vocab_sizes[@]}"; do
+    lang_dir=data/lang_bbpe_${vocab_size}
+    mkdir -p "$lang_dir"
+
+    # Chinese transcripts followed by English transcripts; rewritten every run.
+    cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \
+      > "$lang_dir/text"
+
+    if [ ! -f "$lang_dir/transcript_chars.txt" ]; then
+      ./local/prepare_for_bpe_model.py \
+        --lang-dir "./$lang_dir" \
+        --text "$lang_dir/text"
+    fi
+
+    if [ ! -f "$lang_dir/text_words_segmentation" ]; then
+      python3 ./local/text2segments.py \
+        --input-file ./data/lang_char/text \
+        --output-file "$lang_dir/text_words_segmentation"
+
+      cat ./data/lang_bpe_500/transcript_words.txt \
+        >> "$lang_dir/text_words_segmentation"
+
+      # NOTE(review): $lang_dir/text already starts with data/lang_char/text
+      # (see the cat above), so this append stores the Chinese transcripts
+      # twice, and only on the run that creates text_words_segmentation.
+      # Kept as-is because removing it changes the BBPE training text;
+      # confirm whether the duplication is intended.
+      cat ./data/lang_char/text \
+        >> "$lang_dir/text"
+    fi
+
+    # One word per line, de-duplicated, empty lines dropped. "sort -u" already
+    # yields unique lines, so no extra "uniq" is needed.
+    sed 's/ /\n/g' "$lang_dir/text_words_segmentation" \
+      | sort -u | sed '/^$/d' > "$lang_dir/words_no_ids.txt"
+
+    if [ ! -f "$lang_dir/words.txt" ]; then
+      python3 ./local/prepare_words.py \
+        --input-file "$lang_dir/words_no_ids.txt" \
+        --output-file "$lang_dir/words.txt"
+    fi
+
+    if [ ! -f "$lang_dir/bbpe.model" ]; then
+      ./local/train_bbpe_model.py \
+        --lang-dir "$lang_dir" \
+        --vocab-size "$vocab_size" \
+        --transcript "$lang_dir/text"
+    fi
+
+    if [ ! -f "$lang_dir/L_disambig.pt" ]; then
+      ./local/prepare_lang_bbpe.py --lang-dir "$lang_dir"
+
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon "$lang_dir/lexicon.txt" \
+        --bpe-model "$lang_dir/bbpe.model"
+    fi
+  done
+fi
+
diff --git a/egs/multi_zh_en/ASR/shared b/egs/multi_zh_en/ASR/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/multi_zh_en/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
new file mode 100644
index 000000000..be6e94472
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/asr_datamodule.py
@@ -0,0 +1,385 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SimpleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
+    AudioSamples,
+    OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class AsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=300.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--concatenate-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, utterances (cuts) will be concatenated "
+ "to minimize the amount of padding.",
+ )
+ group.add_argument(
+ "--duration-factor",
+ type=float,
+ default=1.0,
+ help="Determines the maximum duration of a concatenated cut "
+ "relative to the duration of the longest cut in a batch.",
+ )
+ group.add_argument(
+ "--gap",
+ type=float,
+ default=1.0,
+ help="The amount of padding (in seconds) inserted between "
+ "concatenated cuts. This padding is filled with noise when "
+ "noise augmentation is used.",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ if self.args.concatenate_cuts:
+ logging.info(
+ f"Using cut concatenation with duration factor "
+ f"{self.args.duration_factor} and gap {self.args.gap}."
+ )
+ # Cut concatenation should be the first transform in the list,
+ # so that if we e.g. mix noise in, it will fill the gaps between
+ # different utterances.
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SpeechRecognitionDataset(
+ input_strategy=eval(self.args.input_strategy)(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ # NOTE: the PerturbSpeed transform should be added only if we
+ # remove it from data prep stage.
+ # Add on-the-fly speed perturbation; since originally it would
+ # have increased epoch size by 3, we will apply prob 2/3 and use
+ # 3x more epochs.
+ # Speed perturbation probably should come first before
+ # concatenation, but in principle the transforms order doesn't have
+ # to be strict (e.g. could be randomized)
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
+ # Drop feats to be on the safe side.
+ train = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=True,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+ if self.args.concatenate_cuts:
+ transforms = [
+ CutConcatenate(
+ duration_factor=self.args.duration_factor, gap=self.args.gap
+ )
+ ] + transforms
+
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = K2SpeechRecognitionDataset(
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SpeechRecognitionDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
diff --git a/egs/multi_zh_en/ASR/zipformer/beam_search.py b/egs/multi_zh_en/ASR/zipformer/beam_search.py
new file mode 120000
index 000000000..8e2c0a65c
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/beam_search.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py
new file mode 100755
index 000000000..e21e8f052
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/decode.py
@@ -0,0 +1,851 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./zipformer/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp \
+ --max-duration 600 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import AsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_nbest,
+ fast_beam_search_nbest_LG,
+ fast_beam_search_nbest_oracle,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.cut import Cut
+from multi_dataset import MultiDataset
+from train import add_model_arguments, get_model, get_params
+
+from icefall import byte_encode, smart_byte_decode, tokenize_by_CJK_char
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+# Natural log of 1e-10. decode_one_batch uses it as the fill value when it
+# right-pads the features in causal mode.
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ """Build the argument parser for this decoding script.
+
+ It registers the checkpoint-selection options (--epoch, --iter, --avg,
+ --use-averaged-model, --exp-dir), the BPE model / lang dir, the decoding
+ method with its search options, the --use-* dataset switches, and finally
+ the model options added by add_model_arguments().
+ """
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bbpe_2000/bbpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bbpe_2000",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ - fast_beam_search_nbest
+ - fast_beam_search_nbest_oracle
+ - fast_beam_search_nbest_LG
+ If you use fast_beam_search_nbest_LG, you have to specify
+ `--lang-dir`, which should contain `LG.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=20.0,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search,
+ fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle
+ """,
+ )
+
+ parser.add_argument(
+ "--ngram-lm-scale",
+ type=float,
+ default=0.01,
+ help="""
+ Used only when --decoding_method is fast_beam_search_nbest_LG.
+ It specifies the scale for n-gram LM scores.
+ """,
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=64,
+ help="""Used only when --decoding-method is
+ fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
+ and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--num-paths",
+ type=int,
+ default=200,
+ help="""Number of paths for nbest decoding.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--nbest-scale",
+ type=float,
+ default=0.5,
+ help="""Scale applied to lattice scores when computing nbest paths.
+ Used only when the decoding method is fast_beam_search_nbest,
+ fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
+ )
+
+ parser.add_argument(
+ "--use-tal-csasr",
+ type=str2bool,
+ default=False,
+ help="Whether to use TAL-CSASR training data.",
+ )
+
+ parser.add_argument(
+ "--use-librispeech",
+ type=str2bool,
+ default=False,
+ help="Whether to use LibriSpeech training data.",
+ )
+
+ parser.add_argument(
+ "--use-aishell2",
+ type=str2bool,
+ default=False,
+ help="Whether to use Aishell-2 training data.",
+ )
+
+ # Model-architecture options are registered by the helper imported from train.
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[List[str]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+        - key: It indicates the setting used for decoding. For example,
+               if greedy_search is used, it would be "greedy_search".
+               If beam search or modified beam search with a beam size of 7
+               is used, it would be "beam_size_7". For the fast_beam_search
+               family it is built from beam, max_contexts and max_states
+               (plus num_paths / nbest_scale / ngram_lm_scale when relevant).
+        - value: It contains the decoding result. `len(value)` equals to
+                 batch size. `value[i]` is the decoding result for the i-th
+                 utterance in the given batch.
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`. It is not modified by this function.
+      word_table:
+        The word symbol table. Required (indexed by word id) when
+        --decoding-method is fast_beam_search_nbest_LG.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
+        only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
+        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict.
+    Raises:
+      ValueError: if the per-utterance fallback path is reached with a
+        decoding method other than greedy_search or beam_search.
+    """
+
+    def _to_words(token_ids: List[List[int]]) -> List[List[str]]:
+        # sp.decode() -> smart_byte_decode() -> whitespace split, per utterance.
+        return [smart_byte_decode(text).split() for text in sp.decode(token_ids)]
+
+    device = next(model.parameters()).device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    if params.causal:
+        # this seems to cause insertions at the end of the utterance if used with zipformer.
+        pad_len = 30
+        # Out-of-place add on purpose: `.to(device)` can hand back the very
+        # tensor stored in `batch` (e.g. when it already lives on `device`),
+        # and `+=` would then silently change the caller's num_frames.
+        feature_lens = feature_lens + pad_len
+        feature = torch.nn.functional.pad(
+            feature,
+            pad=(0, 0, 0, pad_len),
+            value=LOG_EPS,
+        )
+
+    encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens)
+
+    hyps = []
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        hyps.extend(_to_words(hyp_tokens))
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        # With an LG graph the search returns word ids, not BPE token ids.
+        for hyp in hyp_tokens:
+            hyps.append([word_table[i] for i in hyp])
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        hyps.extend(_to_words(hyp_tokens))
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=sp.encode(
+                byte_encode(tokenize_by_CJK_char(supervisions["text"]))
+            ),
+            nbest_scale=params.nbest_scale,
+        )
+        hyps.extend(_to_words(hyp_tokens))
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        hyps.extend(_to_words(hyp_tokens))
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        hyps.extend(_to_words(hyp_tokens))
+    else:
+        # Remaining methods have no batched implementation here: decode the
+        # utterances one by one on their un-padded encoder output.
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(smart_byte_decode(sp.decode(hyp)).split())
+
+    if params.decoding_method == "greedy_search":
+        return {"greedy_search": hyps}
+    elif "fast_beam_search" in params.decoding_method:
+        key = f"beam_{params.beam}_"
+        key += f"max_contexts_{params.max_contexts}_"
+        key += f"max_states_{params.max_states}"
+        if "nbest" in params.decoding_method:
+            key += f"_num_paths_{params.num_paths}_"
+            key += f"nbest_scale_{params.nbest_scale}"
+            if "LG" in params.decoding_method:
+                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
+
+        return {key: hyps}
+    else:
+        return {f"beam_size_{params.beam_size}": hyps}
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
+    decoding_graph: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Run :func:`decode_one_batch` over every batch of a dataloader.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      sp:
+        The BPE model.
+      word_table:
+        The word symbol table; forwarded to :func:`decode_one_batch`.
+      decoding_graph:
+        The decoding graph; forwarded to :func:`decode_one_batch`.
+    Returns:
+      A dict keyed by the decoding setting reported by
+      :func:`decode_one_batch` (e.g. "greedy_search"). Each value is a list
+      with one tuple per utterance: ``(cut_id, reference_words,
+      hypothesis_words)``.
+    """
+    try:
+        total_batches = len(dl)
+    except TypeError:
+        # The dataloader does not define a length.
+        total_batches = "?"
+
+    log_interval = 50 if params.decoding_method == "greedy_search" else 20
+
+    results = defaultdict(list)
+    cuts_seen = 0
+    for batch_idx, batch in enumerate(dl):
+        supervisions = batch["supervisions"]
+        # References are split with the same CJK-aware tokenizer before
+        # being compared against the hypotheses.
+        ref_words = [
+            tokenize_by_CJK_char(str(text)).split() for text in supervisions["text"]
+        ]
+        cut_ids = [cut.id for cut in supervisions["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+            word_table=word_table,
+            batch=batch,
+        )
+
+        for name, hyps in hyps_dict.items():
+            assert len(hyps) == len(ref_words)
+            results[name].extend(zip(cut_ids, ref_words, hyps))
+
+        cuts_seen += len(ref_words)
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{total_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {cuts_seen}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    """Write transcripts, error statistics and a WER summary for one test set.
+
+    Args:
+      params:
+        Must provide ``res_dir`` (output directory) and ``suffix`` (appended
+        to every file name).
+      test_set_name:
+        Name of the test set; part of every file name and log message.
+      results_dict:
+        Maps a decoding setting to its list of result tuples.
+    """
+    wer_by_setting = dict()
+    for key, results in results_dict.items():
+        file_tag = f"{test_set_name}-{key}-{params.suffix}.txt"
+
+        results = sorted(results)
+        recog_path = params.res_dir / f"recogs-{file_tag}"
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = params.res_dir / f"errs-{file_tag}"
+        with open(errs_filename, "w") as f:
+            wer_by_setting[key] = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+
+        logging.info(f"Wrote detailed error stats to {errs_filename}")
+
+    # Settings ordered from lowest to highest WER.
+    ranked = sorted(wer_by_setting.items(), key=lambda item: item[1])
+
+    # `key` here is whatever setting the loop above visited last; the summary
+    # file name is derived from it.
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for setting, wer in ranked:
+            print(f"{setting}\t{wer}", file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    for rank, (setting, wer) in enumerate(ranked):
+        # Only the first (lowest-WER) line carries the "best" marker.
+        note = "\tbest for {}".format(test_set_name) if rank == 0 else ""
+        s += "{}\t{}{}\n".format(setting, wer, note)
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    """Entry point: parse arguments, load the model and decode all test sets.
+
+    Steps, in order:
+      1. Parse the options from get_parser() plus the data-module options.
+      2. Derive ``params.res_dir`` and ``params.suffix`` from the decoding
+         settings; they name the log and result files.
+      3. Load the BPE model, build the model and load (averaged) checkpoints.
+      4. Build the decoding graph / word table when a fast_beam_search
+         method is selected.
+      5. Decode every test set returned by ``MultiDataset.test_cuts()`` and
+         save the results.
+    """
+    parser = get_parser()
+    AsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.causal:
+        assert (
+            "," not in params.chunk_size
+        ), "chunk_size should be one value in decoding."
+        assert (
+            "," not in params.left_context_frames
+        ), "left_context_frames should be one value in decoding."
+        params.suffix += f"-chunk-{params.chunk_size}"
+        params.suffix += f"-left-context-{params.left_context_frames}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    # (looking up the empty string instead would give both ids the same value)
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        # Plain averaging of the selected checkpoints' parameters.
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        # Only the two end-point checkpoints are passed to the averaging
+        # helper; the start checkpoint itself is excluded from the range.
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to(device)
+    model.eval()
+
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            # The n-gram LM scale is applied once here, on the graph scores.
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+        word_table = None
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+    args.return_cuts = True
+    data_module = AsrDataModule(args)
+    multi_dataset = MultiDataset(args)
+
+    def remove_short_utt(c: Cut):
+        # Keep a cut only if T, computed from its frame count, is positive.
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        if T <= 0:
+            logging.warning(
+                f"Excluding cut with ID: {c.id} from decoding, num_frames: {c.num_frames}"
+            )
+        return T > 0
+
+    test_sets_cuts = multi_dataset.test_cuts()
+
+    test_sets = test_sets_cuts.keys()
+    test_dls = [
+        data_module.test_dataloaders(test_sets_cuts[cuts_name].filter(remove_short_utt))
+        for cuts_name in test_sets
+    ]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        logging.info(f"Start decoding test set: {test_set}")
+
+        results_dict = decode_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            sp=sp,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
+        )
+
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    logging.info("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/multi_zh_en/ASR/zipformer/decoder.py b/egs/multi_zh_en/ASR/zipformer/decoder.py
new file mode 120000
index 000000000..5a8018680
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/decoder.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/encoder_interface.py b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py
new file mode 120000
index 000000000..c2eaca671
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/encoder_interface.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py
new file mode 120000
index 000000000..2962eb784
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/export-onnx-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx-streaming.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/export-onnx.py b/egs/multi_zh_en/ASR/zipformer/export-onnx.py
new file mode 120000
index 000000000..70a15683c
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/export-onnx.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/export.py b/egs/multi_zh_en/ASR/zipformer/export.py
new file mode 100755
index 000000000..fbd9ce0dd
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/export.py
@@ -0,0 +1,541 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+Note: This is an example for the librispeech dataset. If you are using a different
+dataset, you should change the argument values according to your dataset.
+
+(1) Export to torchscript model using torch.jit.script()
+
+- For non-streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 20 \
+ --avg 1 \
+ --jit 1
+
+It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("jit_script.pt")`.
+
+Check ./jit_pretrained.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+- For streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 20 \
+ --avg 1 \
+ --jit 1
+
+It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
+You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
+
+Check ./jit_pretrained_streaming.py for its usage.
+
+Check https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+- For non-streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 20 \
+ --avg 1
+
+- For streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --causal 1 \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 20 \
+ --avg 1
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+- For non-streaming model:
+
+To use the generated file with `zipformer/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/multi_zh_en/ASR
+ ./zipformer/decode.py \
+ --exp-dir ./zipformer/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bbpe_2000/bbpe.model
+
+- For streaming model:
+
+To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+ cd /path/to/egs/multi_zh_en/ASR
+
+ # simulated streaming decoding
+ ./zipformer/decode.py \
+ --exp-dir ./zipformer/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bbpe_2000/bbpe.model
+
+ # chunk-wise streaming decoding
+ ./zipformer/streaming_decode.py \
+ --exp-dir ./zipformer/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bbpe_2000/bbpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you. You can get it at
+
+- non-streaming model:
+https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22/
+
+with the following commands:
+
+ sudo apt-get install git-lfs
+ git lfs install
+ git clone https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22/
+ # You will find the pre-trained models in exp dir
+"""
+
+import argparse
+import logging
+import re
+from pathlib import Path
+from typing import List, Tuple
+
+import k2
+import torch
+from scaling_converter import convert_scaled_to_non_scaled
+from torch import Tensor, nn
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import make_pad_mask, str2bool
+
+
+def num_tokens(
+    token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
+) -> int:
+    """Return the number of tokens excluding those from
+    disambiguation symbols.
+
+    Args:
+      token_table:
+        The token symbol table; its ``symbols`` are iterated and each symbol
+        is looked up to get its ID.
+      disambig_pattern:
+        Regular expression matching disambiguation symbols. Either a pattern
+        string or an already compiled pattern is accepted.
+    Returns:
+      The number of symbols not matched by ``disambig_pattern``.
+
+    Caution:
+      0 is not a token ID so it is excluded from the return value.
+    """
+    # re.compile() returns a compiled pattern unchanged, so both a plain
+    # string (as the annotation promises) and the compiled default work.
+    pattern = re.compile(disambig_pattern)
+
+    token_ids = [
+        token_table[symbol]
+        for symbol in token_table.symbols
+        if not pattern.match(symbol)
+    ]
+    count = len(token_ids)
+    if 0 in token_ids:
+        count -= 1
+    return count
+
+
+def get_parser():
+ """Build the argument parser for this export script.
+
+ It registers the checkpoint-selection options (--epoch, --iter, --avg,
+ --use-averaged-model, --exp-dir), the path to tokens.txt, the --jit
+ switch, the decoder context size, and finally the model options added by
+ add_model_arguments().
+ """
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=20,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=1,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipformer/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/lang_bbpe_2000/tokens.txt",
+ help="Path to the tokens.txt",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named jit_script.pt.
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ # Model-architecture options are registered by the helper imported from train.
+ add_model_arguments(parser)
+
+ return parser
+
+
+class EncoderModel(nn.Module):
+    """Bundle ``encoder_embed`` and ``encoder`` into a single module."""
+
+    def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+        super().__init__()
+        self.encoder = encoder
+        self.encoder_embed = encoder_embed
+
+    def forward(
+        self, features: Tensor, feature_lengths: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        """Run ``encoder_embed`` followed by ``encoder``.
+
+        Args:
+          features: (N, T, C)
+          feature_lengths: (N,)
+
+        Returns:
+          A pair ``(encoder_out, encoder_out_lens)``. ``encoder_out`` is
+          permuted back to batch-first layout before it is returned.
+        """
+        embed_out, embed_lens = self.encoder_embed(features, feature_lengths)
+        padding_mask = make_pad_mask(embed_lens)
+
+        # (N, T, C) -> (T, N, C)
+        time_major = embed_out.permute(1, 0, 2)
+        encoder_out, encoder_out_lens = self.encoder(
+            time_major, embed_lens, padding_mask
+        )
+
+        # (T, N, C) -> (N, T, C)
+        batch_major = encoder_out.permute(1, 0, 2)
+        return batch_major, encoder_out_lens
+
+
+class StreamingEncoderModel(nn.Module):
+    """A wrapper for encoder and encoder_embed, used for streaming.
+
+    Each call to ``forward`` consumes one chunk of features together with a
+    list of cached state tensors, and returns the encoder output for that
+    chunk plus the updated states. The wrapped encoder must be configured
+    with exactly one ``chunk_size`` and one ``left_context_frames`` value
+    (asserted in ``__init__``).
+    """
+
+    def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
+        super().__init__()
+        # Only a single chunk size / left-context setting is accepted here.
+        assert len(encoder.chunk_size) == 1, encoder.chunk_size
+        assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
+        self.chunk_size = encoder.chunk_size[0]
+        self.left_context_len = encoder.left_context_frames[0]
+
+        # The encoder_embed subsample features (T - 7) // 2
+        # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+        # NOTE(review): pad_length is not read anywhere in this class;
+        # presumably it is consumed by code that drives the exported
+        # model - confirm.
+        self.pad_length = 7 + 2 * 3
+
+        self.encoder = encoder
+        self.encoder_embed = encoder_embed
+
+    def forward(
+        self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
+    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+        """Streaming forward for encoder_embed and encoder.
+
+        Args:
+          features: (N, T, C)
+          feature_lengths: (N,)
+          states: a list of Tensors. ``states[:-2]`` are passed to the
+            encoder, ``states[-2]`` is the cached left padding passed to
+            encoder_embed and ``states[-1]`` holds the processed lengths
+            (see ``get_init_states``).
+
+        Returns encoder outputs, output lengths, and updated states.
+        The updated states use the same layout as the input ``states``.
+        """
+        chunk_size = self.chunk_size
+        left_context_len = self.left_context_len
+
+        cached_embed_left_pad = states[-2]
+        x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
+            x=features,
+            x_lens=feature_lengths,
+            cached_left_pad=cached_embed_left_pad,
+        )
+        # After encoder_embed the chunk must contain exactly chunk_size frames.
+        assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+        src_key_padding_mask = make_pad_mask(x_lens)
+
+        # processed_mask is used to mask out initial states
+        processed_mask = torch.arange(left_context_len, device=x.device).expand(
+            x.size(0), left_context_len
+        )
+        processed_lens = states[-1]  # (batch,)
+        # (batch, left_context_size)
+        # NOTE(review): the flip presumably places the not-yet-filled
+        # positions at the start of the left context - confirm.
+        processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+        # Update processed lengths
+        new_processed_lens = processed_lens + x_lens
+
+        # (batch, left_context_size + chunk_size)
+        src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+        # Everything except the last two entries belongs to the encoder.
+        encoder_states = states[:-2]
+
+        (
+            encoder_out,
+            encoder_out_lens,
+            new_encoder_states,
+        ) = self.encoder.streaming_forward(
+            x=x,
+            x_lens=x_lens,
+            states=encoder_states,
+            src_key_padding_mask=src_key_padding_mask,
+        )
+        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
+
+        # Re-assemble the state list in the same order it was received.
+        new_states = new_encoder_states + [
+            new_cached_embed_left_pad,
+            new_processed_lens,
+        ]
+        return encoder_out, encoder_out_lens, new_states
+
+    @torch.jit.export
+    def get_init_states(
+        self,
+        batch_size: int = 1,
+        device: torch.device = torch.device("cpu"),
+    ) -> List[torch.Tensor]:
+        """
+        Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+        is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+        states[-2] is the cached left padding for ConvNeXt module,
+        of shape (batch_size, num_channels, left_pad, num_freqs)
+        states[-1] is processed_lens of shape (batch,), which records the number
+        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+        """
+        # Encoder caches come first, then the encoder_embed cache, then the
+        # processed-length counter; forward() indexes the list accordingly.
+        states = self.encoder.get_init_states(batch_size, device)
+
+        embed_states = self.encoder_embed.get_init_states(batch_size, device)
+        states.append(embed_states)
+
+        processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+        states.append(processed_lens)
+
+        return states
+
+
+@torch.no_grad()
+def main():
+    """Load the selected checkpoint(s) and export the model.
+
+    The weights are taken from a single checkpoint, from a plain average
+    of several checkpoints, or from the averaged-model statistics stored
+    in two checkpoints, depending on ``--use-averaged-model``, ``--iter``,
+    ``--epoch`` and ``--avg``. The result is written to ``--exp-dir``
+    either as a TorchScript file (``--jit 1``) or as ``pretrained.pt``
+    holding ``{"model": state_dict}``.
+
+    Raises:
+      ValueError: if ``--iter`` is given and fewer checkpoints than
+        required are found.
+    """
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    # if torch.cuda.is_available():
+    #     device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    # The blank symbol is stored under "<blk>" in tokens.txt.
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                # Epochs count from 1, so skip anything below that.
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            # One extra checkpoint is needed: the last entry of the slice
+            # is the (excluded) start, the first entry is the end.
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.eval()
+
+    if params.jit is True:
+        convert_scaled_to_non_scaled(model, inplace=True)
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+
+        # Wrap encoder and encoder_embed as a module
+        if params.causal:
+            model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
+            chunk_size = model.encoder.chunk_size
+            left_context_len = model.encoder.left_context_len
+            filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
+        else:
+            model.encoder = EncoderModel(model.encoder, model.encoder_embed)
+            filename = "jit_script.pt"
+
+        logging.info("Using torch.jit.script")
+        model = torch.jit.script(model)
+        model.save(str(params.exp_dir / filename))
+        logging.info(f"Saved to {filename}")
+    else:
+        logging.info("Not using torchscript. Export model.state_dict()")
+        # Save it using a format so that it can be loaded
+        # by :func:`load_checkpoint`
+        filename = params.exp_dir / "pretrained.pt"
+        torch.save({"model": model.state_dict()}, str(filename))
+        logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py
new file mode 100755
index 000000000..68111fad7
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/generate_averaged_model.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) use the checkpoint exp_dir/epoch-xxx.pt
+./zipformer/generate_averaged_model.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./zipformer/exp
+
+It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
+You can later load it by `torch.load("epoch-28-avg-15.pt")`.
+
+(2) use the checkpoint exp_dir/checkpoint-iter.pt
+./zipformer/generate_averaged_model.py \
+ --iter 22000 \
+ --avg 5 \
+ --exp-dir ./zipformer/exp
+
+It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
+You can later load it by `torch.load("iter-22000-avg-5.pt")`.
+"""
+
+
+import argparse
+from pathlib import Path
+
+import k2
+import torch
+from train import add_model_arguments, get_model, get_params
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
+
+
+def get_parser():
+    """Build the command-line parser for the model-averaging script.
+
+    ``--epoch`` / ``--iter`` and ``--avg`` choose the checkpoint range,
+    ``--exp-dir`` says where the checkpoints live and ``--tokens`` points
+    to the token table. Further options are registered by
+    ``add_model_arguments``.
+
+    Returns:
+      The configured ``argparse.ArgumentParser``.
+    """
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=30,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=9,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer/exp",
+        help="The experiment dir",
+    )
+
+    # Same default as zipformer/export.py in this recipe, so both scripts
+    # read the same token table when --tokens is not given.
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/lang_bbpe_2000/tokens.txt",
+        help="Path to the tokens.txt",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+@torch.no_grad()
+def main():
+    """Average checkpoints and save the result into ``--exp-dir``.
+
+    With ``--iter N`` the output file is ``iter-N-avg-K.pt``, otherwise it
+    is ``epoch-E-avg-K.pt``. The file contains ``{"model": state_dict}``.
+
+    Raises:
+      ValueError: if ``--iter`` is given and fewer than ``avg + 1``
+        checkpoints are found.
+    """
+    parser = get_parser()
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    print("Script started")
+
+    device = torch.device("cpu")
+    print(f"Device: {device}")
+
+    symbol_table = k2.SymbolTable.from_file(params.tokens)
+    # Blank and unknown symbols are stored under "<blk>" and "<unk>".
+    params.blank_id = symbol_table["<blk>"]
+    params.unk_id = symbol_table["<unk>"]
+    params.vocab_size = len(symbol_table)
+
+    print("About to create model")
+    model = get_model(params)
+
+    # Step 1: work out the two checkpoints that delimit the averaging range.
+    if params.iter > 0:
+        # avg + 1 checkpoints are needed: the last entry of the slice is the
+        # (excluded) start, the first entry is the end.
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg + 1
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg + 1:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        filename_start = filenames[-1]
+        filename_end = filenames[0]
+        print(
+            "Calculating the averaged model over iteration checkpoints"
+            f" from {filename_start} (excluded) to {filename_end}"
+        )
+    else:
+        assert params.avg > 0, params.avg
+        start = params.epoch - params.avg
+        assert start >= 1, start
+        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+        print(
+            f"Calculating the averaged model over epoch range from "
+            f"{start} (excluded) to {params.epoch}"
+        )
+
+    # Step 2: load the averaged weights and save them. params.suffix was
+    # built above with exactly the iter-/epoch- naming used for the output.
+    model.to(device)
+    model.load_state_dict(
+        average_checkpoints_with_averaged_model(
+            filename_start=filename_start,
+            filename_end=filename_end,
+            device=device,
+        )
+    )
+    filename = params.exp_dir / f"{params.suffix}.pt"
+    torch.save({"model": model.state_dict()}, filename)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+    print("Done!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py
new file mode 120000
index 000000000..25108391f
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py
new file mode 120000
index 000000000..9a8da5844
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_ctc.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py
new file mode 120000
index 000000000..1962351e9
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/jit_pretrained_streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/joiner.py b/egs/multi_zh_en/ASR/zipformer/joiner.py
new file mode 120000
index 000000000..5b8a36332
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/joiner.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/model.py b/egs/multi_zh_en/ASR/zipformer/model.py
new file mode 120000
index 000000000..cd7e07d72
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/model.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/multi_dataset.py b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
new file mode 100644
index 000000000..1155a3dcc
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/multi_dataset.py
@@ -0,0 +1,247 @@
+# Copyright 2023 Xiaomi Corp. (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Dict
+
+from lhotse import CutSet, load_manifest_lazy
+
+
+class MultiDataset:
+    """Loads the cut manifests of AISHELL-2, LibriSpeech and TAL-CSASR.
+
+    All manifests are opened lazily with ``load_manifest_lazy`` from the
+    directory given by ``args.manifest_dir``.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        """
+        Args:
+          args:
+            Parsed command-line options. The following attributes are read:
+              - manifest_dir: directory with the ``*.jsonl.gz`` cut
+                manifests, e.g. aishell2_cuts_train.jsonl.gz
+              - use_tal_csasr / use_librispeech / use_aishell2: whether the
+                corresponding corpus is enabled
+        """
+        self.fbank_dir = Path(args.manifest_dir)
+        self.use_tal_csasr = args.use_tal_csasr
+        self.use_librispeech = args.use_librispeech
+        self.use_aishell2 = args.use_aishell2
+
+    def train_cuts(self) -> CutSet:
+        """Return the training cuts of every enabled corpus.
+
+        When several cut sets are selected they are multiplexed with
+        ``CutSet.mux``, each weighted by its number of cuts. When exactly
+        one cut set is selected it is returned as is.
+
+        Raises:
+          NotImplementedError: if no corpus is enabled.
+        """
+        logging.info("About to get multidataset train cuts")
+
+        aishell2_sets = []
+        tal_csasr_sets = []
+        librispeech_sets = []
+
+        # AISHELL-2
+        if self.use_aishell2:
+            logging.info("Loading Aishell-2 in lazy mode")
+            aishell2_sets.append(
+                load_manifest_lazy(self.fbank_dir / "aishell2_cuts_train.jsonl.gz")
+            )
+
+        # TAL-CSASR
+        if self.use_tal_csasr:
+            logging.info("Loading TAL-CSASR in lazy mode")
+            tal_csasr_sets.append(
+                load_manifest_lazy(
+                    self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz"
+                )
+            )
+
+        # LibriSpeech
+        if self.use_librispeech:
+            logging.info("Loading LibriSpeech in lazy mode")
+            librispeech_sets.extend(
+                [
+                    self.train_clean_100_cuts(),
+                    self.train_clean_360_cuts(),
+                    self.train_other_500_cuts(),
+                ]
+            )
+
+        # Keep the multiplexing order fixed: AISHELL-2, LibriSpeech, TAL-CSASR.
+        cut_sets = aishell2_sets + librispeech_sets + tal_csasr_sets
+
+        if not cut_sets:
+            raise NotImplementedError(
+                f"""Not implemented for
+                use_aishell2: {self.use_aishell2}
+                use_librispeech: {self.use_librispeech}
+                use_tal_csasr: {self.use_tal_csasr}"""
+            )
+
+        if len(cut_sets) == 1:
+            # Nothing to multiplex with; hand back the single cut set.
+            return cut_sets[0]
+
+        return CutSet.mux(
+            *cut_sets,
+            weights=[len(cuts) for cuts in cut_sets],
+        )
+
+    def dev_cuts(self) -> CutSet:
+        """Return the multiplexed dev cuts of all three corpora.
+
+        Note: the ``use_*`` flags are not consulted here; the AISHELL-2,
+        LibriSpeech and TAL-CSASR dev manifests are always loaded.
+        """
+        logging.info("About to get multidataset dev cuts")
+
+        # AISHELL-2
+        logging.info("Loading Aishell-2 DEV set in lazy mode")
+        aishell2_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
+        )
+
+        # LibriSpeech
+        dev_clean_cuts = self.dev_clean_cuts()
+        dev_other_cuts = self.dev_other_cuts()
+
+        logging.info("Loading TAL-CSASR set in lazy mode")
+        tal_csasr_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
+        )
+
+        return CutSet.mux(
+            aishell2_dev_cuts,
+            dev_clean_cuts,
+            dev_other_cuts,
+            tal_csasr_dev_cuts,
+            weights=[
+                len(aishell2_dev_cuts),
+                len(dev_clean_cuts),
+                len(dev_other_cuts),
+                len(tal_csasr_dev_cuts),
+            ],
+        )
+
+    def test_cuts(self) -> Dict[str, CutSet]:
+        """Return the evaluation cut sets keyed by a display name.
+
+        The TAL-CSASR test and dev sets are always included. The AISHELL-2
+        and LibriSpeech sets are added only when the corresponding
+        ``use_*`` flag is set.
+        """
+        logging.info("About to get multidataset test cuts")
+
+        # AISHELL-2
+        if self.use_aishell2:
+            logging.info("Loading Aishell-2 set in lazy mode")
+            aishell2_test_cuts = load_manifest_lazy(
+                self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
+            )
+            aishell2_dev_cuts = load_manifest_lazy(
+                self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
+            )
+
+        # LibriSpeech
+        if self.use_librispeech:
+            test_clean_cuts = self.test_clean_cuts()
+            test_other_cuts = self.test_other_cuts()
+
+        logging.info("Loading TAL-CSASR set in lazy mode")
+        tal_csasr_test_cuts = load_manifest_lazy(
+            self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz"
+        )
+        tal_csasr_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
+        )
+
+        test_cuts = {
+            "tal_csasr_test": tal_csasr_test_cuts,
+            "tal_csasr_dev": tal_csasr_dev_cuts,
+        }
+
+        if self.use_aishell2:
+            test_cuts.update(
+                {
+                    "aishell-2_test": aishell2_test_cuts,
+                    "aishell-2_dev": aishell2_dev_cuts,
+                }
+            )
+        if self.use_librispeech:
+            test_cuts.update(
+                {
+                    "librispeech_test_clean": test_clean_cuts,
+                    "librispeech_test_other": test_other_cuts,
+                }
+            )
+        return test_cuts
+
+    @lru_cache()
+    def train_clean_100_cuts(self) -> CutSet:
+        """LibriSpeech train-clean-100 cuts (lazy, cached per instance)."""
+        logging.info("About to get train-clean-100 cuts")
+        return load_manifest_lazy(
+            self.fbank_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_clean_360_cuts(self) -> CutSet:
+        """LibriSpeech train-clean-360 cuts (lazy, cached per instance)."""
+        logging.info("About to get train-clean-360 cuts")
+        return load_manifest_lazy(
+            self.fbank_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+        )
+
+    @lru_cache()
+    def train_other_500_cuts(self) -> CutSet:
+        """LibriSpeech train-other-500 cuts (lazy, cached per instance)."""
+        logging.info("About to get train-other-500 cuts")
+        return load_manifest_lazy(
+            self.fbank_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+        )
+
+    @lru_cache()
+    def dev_clean_cuts(self) -> CutSet:
+        """LibriSpeech dev-clean cuts (lazy, cached per instance)."""
+        logging.info("About to get dev-clean cuts")
+        return load_manifest_lazy(
+            self.fbank_dir / "librispeech_cuts_dev-clean.jsonl.gz"
+        )
+
+    @lru_cache()
+    def dev_other_cuts(self) -> CutSet:
+        """LibriSpeech dev-other cuts (lazy, cached per instance)."""
+        logging.info("About to get dev-other cuts")
+        return load_manifest_lazy(
+            self.fbank_dir / "librispeech_cuts_dev-other.jsonl.gz"
+        )
+
+    @lru_cache()
+    def test_clean_cuts(self) -> CutSet:
+        """LibriSpeech test-clean cuts (lazy, cached per instance)."""
+        logging.info("About to get test-clean cuts")
+        return load_manifest_lazy(
+            self.fbank_dir / "librispeech_cuts_test-clean.jsonl.gz"
+        )
+
+    @lru_cache()
+    def test_other_cuts(self) -> CutSet:
+        """LibriSpeech test-other cuts (lazy, cached per instance)."""
+        logging.info("About to get test-other cuts")
+        return load_manifest_lazy(
+            self.fbank_dir / "librispeech_cuts_test-other.jsonl.gz"
+        )
diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_check.py b/egs/multi_zh_en/ASR/zipformer/onnx_check.py
new file mode 120000
index 000000000..f3dd42004
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/onnx_check.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_check.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_decode.py b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py
new file mode 120000
index 000000000..0573b88c5
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/onnx_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_decode.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py
new file mode 120000
index 000000000..cfea104c2
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained-streaming.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py
new file mode 120000
index 000000000..8f32f4ee7
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/onnx_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/optim.py b/egs/multi_zh_en/ASR/zipformer/optim.py
new file mode 120000
index 000000000..5eaa3cffd
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/optim.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/pretrained.py b/egs/multi_zh_en/ASR/zipformer/pretrained.py
new file mode 100755
index 000000000..676272e1f
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/pretrained.py
@@ -0,0 +1,378 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a checkpoint and uses it to decode waves.
+You can generate the checkpoint with the following command:
+
+Note: This is an example for the librispeech dataset. If you are using a different
+dataset, you should change the argument values according to your dataset.
+
+- For non-streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 23 \
+ --avg 1
+
+- For streaming model:
+
+./zipformer/export.py \
+ --exp-dir ./zipformer/exp \
+ --causal 1 \
+ --tokens data/lang_bbpe_2000/tokens.txt \
+ --epoch 23 \
+ --avg 1
+
+Usage of this script:
+
+- For non-streaming model:
+
+(1) greedy search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --bpe-model data/lang_bbpe_2000/bbpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --bpe-model ./data/lang_bbpe_2000/bbpe.model \
+ --method modified_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --bpe-model ./data/lang_bbpe_2000/bbpe.model \
+ --method fast_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+- For streaming model:
+
+(1) greedy search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --bpe-model ./data/lang_bbpe_2000/bbpe.model \
+ --method greedy_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(2) modified beam search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --bpe-model ./data/lang_bbpe_2000/bbpe.model \
+ --method modified_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+(3) fast beam search
+./zipformer/pretrained.py \
+ --checkpoint ./zipformer/exp/pretrained.pt \
+ --causal 1 \
+ --chunk-size 16 \
+ --left-context-frames 128 \
+ --bpe-model ./data/lang_bbpe_2000/bbpe.model \
+ --method fast_beam_search \
+ /path/to/foo.wav \
+ /path/to/bar.wav
+
+
+You can also use `./zipformer/exp/epoch-xx.pt`.
+
+Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torchaudio
+from beam_search import (
+ fast_beam_search_one_best,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from export import num_tokens
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_model, get_params
+
+from icefall import smart_byte_decode
+
+
+def get_parser():
+    """Build the command-line parser for decoding with a saved checkpoint.
+
+    ``--checkpoint`` and ``--bpe-model`` are required; one or more sound
+    files are given as positional arguments. The remaining options choose
+    the decoding method and its settings. Further options are registered
+    by ``add_model_arguments``.
+
+    Returns:
+      The configured ``argparse.ArgumentParser``.
+    """
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    # main() passes this path straight to SentencePieceProcessor.load(), so
+    # it must always be given; make argparse report a missing value instead
+    # of failing later with None.
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        required=True,
+        help="""Path to byte-level bpe model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Load sound files and keep the first channel of each one.
+
+    Args:
+      filenames:
+        Paths of the sound files to read.
+      expected_sample_rate:
+        Sample rate every file is required to have.
+    Returns:
+      One 1-D tensor per input file, in the same order as ``filenames``.
+    """
+    first_channels = []
+    for path in filenames:
+        loaded = torchaudio.load(path)
+        audio = loaded[0]
+        sample_rate = loaded[1]
+        assert (
+            sample_rate == expected_sample_rate
+        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
+        # Only channel 0 is kept; any further channels are dropped.
+        mono = audio[0].contiguous()
+        first_channels.append(mono)
+    return first_channels
+
+
+@torch.no_grad()
+def main():
+    """Decode the given sound files with a saved checkpoint and log the results.
+
+    Steps: load the BPE model and the checkpoint, compute fbank features
+    for all files as one padded batch, run the encoder, decode with the
+    method selected by ``--method`` and log one hypothesis per file.
+
+    Raises:
+      ValueError: if ``--method`` is not one of the supported methods, or
+        if it is ``greedy_search`` with ``--max-sym-per-frame`` other than 1.
+    """
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    if params.causal:
+        # chunk_size / left_context_frames are comma-separated strings; a
+        # single value is required when decoding.
+        assert (
+            "," not in params.chunk_size
+        ), "chunk_size should be one value in decoding."
+        assert (
+            "," not in params.left_context_frames
+        ), "left_context_frames should be one value in decoding."
+
+    logging.info("Creating model")
+    model = get_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    # Pad to a common length with log(1e-10).
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    # model forward
+    encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
+
+    msg = f"Using {params.method}"
+    logging.info(msg)
+
+    if params.method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+    else:
+        raise ValueError(f"Unsupported method: {params.method}")
+
+    # Every method yields token IDs; turn them into text once, here. The
+    # BPE model is byte-level, so the decoded string is passed through
+    # smart_byte_decode before it is split into words.
+    hyps = [smart_byte_decode(hyp).split() for hyp in sp.decode(hyp_tokens)]
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        s += f"{filename}:\n{hyp}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/multi_zh_en/ASR/zipformer/scaling.py b/egs/multi_zh_en/ASR/zipformer/scaling.py
new file mode 120000
index 000000000..6f398f431
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/scaling_converter.py b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py
new file mode 120000
index 000000000..b0ecee05e
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/scaling_converter.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py
new file mode 120000
index 000000000..b1ed54557
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/streaming_decode.py b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py
new file mode 120000
index 000000000..13fd02a78
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/streaming_decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/streaming_decode.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/subsampling.py b/egs/multi_zh_en/ASR/zipformer/subsampling.py
new file mode 120000
index 000000000..01ae9002c
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/subsampling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/subsampling.py
\ No newline at end of file
diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py
new file mode 100755
index 000000000..310c8fe59
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/train.py
@@ -0,0 +1,1416 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao,
+# Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+# For non-streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --max-duration 1000
+
+# For streaming model training:
+./zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir zipformer/exp \
+ --causal 1 \
+ --max-duration 1000
+
+It supports training with:
+ - transducer loss (default), with `--use-transducer True --use-ctc False`
+ - ctc loss (not recommended), with `--use-transducer False --use-ctc True`
+ - transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import AsrModel
+from multi_dataset import MultiDataset
+from optim import Eden, ScaledAdam
+from scaling import ScheduledFloat
+from subsampling import Conv2dSubsampling
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer2
+
+from icefall import byte_encode, diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+ tokenize_by_CJK_char,
+)
+
+# Type alias for the learning-rate scheduler arguments in this file: either a
+# torch scheduler or the LRScheduler from the local `optim` module.
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+    """Return the batch count rescaled to the reference batch duration.
+
+    This is the number of batches we would have used so far if each batch
+    had lasted `params.ref_duration` instead of
+    `params.max_duration * params.world_size`. The result is meant to be
+    passed to set_batch_count().
+    """
+    duration_per_step = params.max_duration * params.world_size
+    return params.batch_idx_train * duration_per_step / params.ref_duration
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    """Push `batch_count` (and each module's own name) into the sub-modules.
+
+    Every sub-module that already has a `batch_count` attribute gets it
+    overwritten with `batch_count`; every sub-module that already has a
+    `name` attribute gets it set to its name from `named_modules()`.
+
+    Args:
+      model:
+        The model, possibly wrapped in DDP.
+      batch_count:
+        The value to store, see get_adjusted_batch_count().
+    """
+    # Work on the underlying nn.Module when the model is wrapped in DDP.
+    root = model.module if isinstance(model, DDP) else model
+    for qualified_name, submodule in root.named_modules():
+        if hasattr(submodule, "batch_count"):
+            submodule.batch_count = batch_count
+        if hasattr(submodule, "name"):
+            submodule.name = qualified_name
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    """Register the model-architecture options on `parser`.
+
+    Options that may differ per encoder stack are declared with `type=str`
+    and hold either a single int or a comma-separated list of ints; they are
+    converted with _to_int_tuple() when the model is built.
+
+    Args:
+      parser:
+        The argument parser to extend in place.
+    """
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,2,3,4,3,2",
+        help="Number of zipformer encoder layers per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--downsampling-factor",
+        type=str,
+        default="1,2,4,8,4,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dim",
+        type=str,
+        default="512,768,1024,1536,1024,768",
+        help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+    )
+
+    parser.add_argument(
+        "--num-heads",
+        type=str,
+        default="4,4,4,8,4,4",
+        help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=str,
+        default="192,256,384,512,384,256",
+        help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--query-head-dim",
+        type=str,
+        default="32",
+        help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--value-head-dim",
+        type=str,
+        default="12",
+        help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--pos-head-dim",
+        type=str,
+        default="4",
+        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+    )
+
+    # The default is an int so that it matches the declared `type=int`
+    # instead of relying on argparse converting a string default.
+    parser.add_argument(
+        "--pos-dim",
+        type=int,
+        default=48,
+        help="Positional-encoding embedding dimension",
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dim",
+        type=str,
+        default="192,192,256,256,256,192",
+        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+        "A single int or comma-separated list. Must be <= each corresponding encoder_dim.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernel",
+        type=str,
+        default="31,31,15,15,15,31",
+        help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+        "a single int or comma-separated list.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal",
+        type=str2bool,
+        default=False,
+        help="If True, use causal version of model.",
+    )
+
+    parser.add_argument(
+        "--chunk-size",
+        type=str,
+        default="16,32,64,-1",
+        help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. "
+        " Must be just -1 if --causal=False",
+    )
+
+    parser.add_argument(
+        "--left-context-frames",
+        type=str,
+        default="64,128,256,-1",
+        help="Maximum left-contexts for causal training, measured in frames which will "
+        "be converted to a number of chunks. If splitting into chunks, "
+        "chunk left-context frames will be chosen randomly from this list; else not relevant.",
+    )
+
+    parser.add_argument(
+        "--use-transducer",
+        type=str2bool,
+        default=True,
+        help="If True, use Transducer head.",
+    )
+
+    parser.add_argument(
+        "--use-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use CTC head.",
+    )
+
+
+def get_parser():
+    """Build the command-line parser for training.
+
+    The parser holds the training options declared here plus the model
+    options registered by add_model_arguments().
+
+    Returns:
+      Return the configured `argparse.ArgumentParser`.
+    """
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bbpe_2000/bbpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--base-lr", type=float, default=0.045, help="The base learning rate."
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=7500,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=3.5,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--ref-duration",
+        type=float,
+        default=600,
+        help="Reference batch duration for purposes of adjusting batch counts for setting various "
+        "schedules inside the model",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    # NOTE: adjacent string literals are joined without a separator, so each
+    # fragment below ends with an explicit space.
+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context) "
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.25,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
+    parser.add_argument(
+        "--simple-loss-scale",
+        type=float,
+        default=0.5,
+        help="To get pruning ranges, we will calculate a simple version "
+        "loss(joiner is just addition), this simple loss also uses for "
+        "training (as a regularization item). We will scale the simple loss "
+        "with this parameter before adding to the final loss.",
+    )
+
+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    # argparse applies %-formatting to help strings, so the literal modulo
+    # sign in the text below has to be written as `%%`.
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=4000,
+        help="""Save checkpoint after processing this number of batches
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train %% save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=30,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=200,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+            model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--use-tal-csasr",
+        type=str2bool,
+        default=False,
+        help="Whether to use TAL-CSASR training data.",
+    )
+
+    parser.add_argument(
+        "--use-librispeech",
+        type=str2bool,
+        default=False,
+        help="Whether to use LibriSpeech training data.",
+    )
+
+    parser.add_argument(
+        "--use-aishell2",
+        type=str2bool,
+        default=False,
+        help="Whether to use Aishell-2 training data.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return the training parameters that are not set on the commandline.
+
+    Commandline options are merged into the returned `params` after they
+    are parsed (see :func:`run`), so they can be accessed via `params` too.
+
+    Explanation of the entries created here:
+
+        - best_train_loss: Best training loss so far. Updated during
+                           training.
+
+        - best_valid_loss: Best validation loss so far. Updated during
+                           training.
+
+        - best_train_epoch: The epoch that has the best training loss.
+
+        - best_valid_epoch: The epoch that has the best validation loss.
+
+        - batch_idx_train: Number of batches trained so far across epochs.
+
+        - log_interval: Print training loss if batch_idx % log_interval is 0
+
+        - reset_interval: Decay constant of the running training loss.
+
+        - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor: The subsampling factor for the model.
+
+        - warm_step: The warmup period that dictates the decay of the
+                     scale on "simple" (un-pruned) loss.
+
+        - env_info: The value returned by get_env_info().
+    """
+    defaults: Dict[str, Any] = {
+        "best_train_loss": float("inf"),
+        "best_valid_loss": float("inf"),
+        "best_train_epoch": -1,
+        "best_valid_epoch": -1,
+        "batch_idx_train": 0,
+        "log_interval": 50,
+        "reset_interval": 200,
+        "valid_interval": 3000,  # For the 100h subset, use 800
+        # parameters for zipformer
+        "feature_dim": 80,
+        "subsampling_factor": 4,  # not passed in, this is fixed.
+        "warm_step": 2000,
+        "env_info": get_env_info(),
+    }
+    return AttributeDict(defaults)
+
+
+def _to_int_tuple(s: str):
+    """Parse a comma-separated string such as "2,2,3" into a tuple of ints."""
+    return tuple(int(field) for field in s.split(","))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+    """Build the convolutional front-end that feeds the encoder.
+
+    encoder_embed converts the input of shape (N, T, num_features)
+    to the shape (N, (T - 7) // 2, encoder_dims).
+    That is, it does two things simultaneously:
+      (1) subsampling: T -> (T - 7) // 2
+      (2) embedding: num_features -> encoder_dims
+    In the normal configuration, we will downsample once more at the end
+    by a factor of 2, and most of the encoder stacks will run at a lower
+    sampling rate.
+    """
+    # The embedding width is the dimension of the first encoder stack.
+    first_stack_dim = _to_int_tuple(params.encoder_dim)[0]
+    return Conv2dSubsampling(
+        in_channels=params.feature_dim,
+        out_channels=first_stack_dim,
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+    )
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+    """Build the Zipformer2 encoder from the model options in `params`."""
+    # Options given as comma-separated strings; each one is converted with
+    # _to_int_tuple() and passed on under the same name.
+    per_stack_options = (
+        "downsampling_factor",
+        "num_encoder_layers",
+        "encoder_dim",
+        "encoder_unmasked_dim",
+        "query_head_dim",
+        "pos_head_dim",
+        "value_head_dim",
+        "num_heads",
+        "feedforward_dim",
+        "cnn_module_kernel",
+        "chunk_size",
+        "left_context_frames",
+    )
+    parsed = {
+        option: _to_int_tuple(getattr(params, option))
+        for option in per_stack_options
+    }
+    return Zipformer2(
+        output_downsampling_factor=2,
+        pos_dim=params.pos_dim,
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+        warmup_batches=4000.0,
+        causal=params.causal,
+        **parsed,
+    )
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+    """Build the transducer decoder from `params`."""
+    return Decoder(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.decoder_dim,
+        blank_id=params.blank_id,
+        context_size=params.context_size,
+    )
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+    """Build the joiner; its encoder input dim is the largest encoder stack dim."""
+    widest_encoder_dim = max(_to_int_tuple(params.encoder_dim))
+    return Joiner(
+        encoder_dim=widest_encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
+    )
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+    """Assemble the AsrModel described by `params`.
+
+    At least one of `params.use_transducer` and `params.use_ctc` must be
+    True. The decoder and the joiner are only created for the transducer
+    head; otherwise None is passed for both.
+    """
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
+        f"but got params.use_transducer={params.use_transducer}, "
+        f"params.use_ctc={params.use_ctc}"
+    )
+
+    encoder_embed = get_encoder_embed(params)
+    encoder = get_encoder_model(params)
+
+    decoder = get_decoder_model(params) if params.use_transducer else None
+    joiner = get_joiner_model(params) if params.use_transducer else None
+
+    return AsrModel(
+        encoder_embed=encoder_embed,
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
+        decoder_dim=params.decoder_dim,
+        vocab_size=params.vocab_size,
+        use_transducer=params.use_transducer,
+        use_ctc=params.use_ctc,
+    )
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.exp_dir/epoch-{params.start_epoch-1}.pt`. Otherwise nothing is
+    loaded.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    `best_valid_loss` and `batch_idx_train` in `params`. When resuming from
+    a batch checkpoint that recorded `cur_epoch`, `params.start_epoch` is
+    set to that epoch.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info, or None if
+      there is no checkpoint to resume from.
+    """
+    resume_from_batch = params.start_batch > 0
+    if resume_from_batch:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    # Restore the bookkeeping that training keeps in `params`.
+    for key in (
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ):
+        params[key] = saved_params[key]
+
+    # A batch checkpoint is written in the middle of an epoch, so training
+    # continues from the epoch stored in it, if any.
+    if resume_from_batch and "cur_epoch" in saved_params:
+        params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Write the end-of-epoch checkpoint `exp-dir/epoch-{cur_epoch}.pt`.
+
+    The file is additionally copied to `best-train-loss.pt` and/or
+    `best-valid-loss.pt` when the current epoch is the best one so far
+    for the corresponding loss.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      scheduler:
+        The learning rate scheduler used in the training.
+      sampler:
+        The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+      rank:
+        The rank of the calling process. Nothing is done unless it is 0.
+    """
+    if rank != 0:
+        return
+
+    exp_dir = params.exp_dir
+    filename = exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    best_copies = (
+        (params.best_train_epoch, "best-train-loss.pt"),
+        (params.best_valid_epoch, "best-valid-loss.pt"),
+    )
+    for best_epoch, best_name in best_copies:
+        if best_epoch == params.cur_epoch:
+            copyfile(src=filename, dst=exp_dir / best_name)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute the loss of one batch.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      sp:
+        The SentencePiece processor used to encode the supervision texts.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+    Returns:
+      Return a tuple (loss, info): the combined loss and a MetricsTracker
+      holding the frame count and the individual loss values.
+    """
+    if isinstance(model, DDP):
+        device = model.device
+    else:
+        device = next(model.parameters()).device
+
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    step = params.batch_idx_train
+    warm_step = params.warm_step
+
+    token_ids = sp.encode(supervisions["text"], out_type=int)
+    y = k2.RaggedTensor(token_ids)
+
+    with torch.set_grad_enabled(is_training):
+        simple_loss, pruned_loss, ctc_loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            prune_range=params.prune_range,
+            am_scale=params.am_scale,
+            lm_scale=params.lm_scale,
+        )
+
+        loss = 0.0
+
+        if params.use_transducer:
+            s = params.simple_loss_scale
+            # take down the scale on the simple loss from 1.0 at the start
+            # to params.simple_loss scale by warm_step, while the scale on
+            # the pruned loss goes up from 0.1 to 1.0.
+            if step >= warm_step:
+                simple_loss_scale = s
+                pruned_loss_scale = 1.0
+            else:
+                simple_loss_scale = 1.0 - (step / warm_step) * (1.0 - s)
+                pruned_loss_scale = 0.1 + 0.9 * (step / warm_step)
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+        if params.use_ctc:
+            loss += params.ctc_loss_scale * ctc_loss
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    if params.use_transducer:
+        info["simple_loss"] = simple_loss.detach().cpu().item()
+        info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if params.use_ctc:
+        info["ctc_loss"] = ctc_loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process.
+
+    The model is put into eval mode and is left there; the caller has to
+    switch back to training mode. `params.best_valid_loss` and
+    `params.best_valid_epoch` are updated when the loss per frame improves.
+
+    Returns:
+      Return the statistics accumulated over `valid_dl`.
+    """
+    model.eval()
+
+    accumulated = MetricsTracker()
+    for batch in valid_dl:
+        loss, batch_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        accumulated = accumulated + batch_info
+
+    if world_size > 1:
+        # The loss of the last batch tells us which device to reduce on.
+        accumulated.reduce(loss.device)
+
+    valid_loss = accumulated["loss"] / accumulated["frames"]
+    if valid_loss < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = valid_loss
+
+    return accumulated
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ sp:
+ The SentencePiece processor; it is forwarded to compute_loss() and,
+ on failure, to display_and_save_batch().
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ # Set once the "-first-warning" checkpoint below has been written, so that
+ # it is written at most once per call.
+ saved_bad_model = False
+
+ # Dump the complete training state to exp_dir/bad-model{suffix}-{rank}.pt
+ # for post-mortem debugging.
+ # NOTE(review): rank=0 is passed to save_checkpoint_impl regardless of the
+ # caller's rank, presumably so that every rank writes its own file -- confirm.
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint_impl(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ # Refresh the batch count stored in the modules every 10 batches.
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ # tot_loss is a decayed running sum: the previous value is scaled by
+ # (1 - 1 / reset_interval) before the current batch is added.
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ # Deliberately catches everything: the training state and the offending
+ # batch are saved for inspection and the exception is re-raised.
+ except: # noqa
+ save_bad_model()
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ # In diagnostics mode, stop as soon as batch_idx reaches 5.
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ # Only rank 0 updates the averaged model, every average_period batches.
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ # Periodic checkpoint-{batch_idx_train}.pt; only the newest keep_last_k
+ # of them are kept.
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ # Double the scale when it is below 8, or below 32 on every 400th batch.
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ # Below 0.01: warn and keep one checkpoint; below 1e-5: give up.
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ # Without fp16 the grad scale is reported as 1.0.
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ # compute_validation_loss() puts the model into eval mode; switch back.
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ # Record the running loss per frame and update the best training
+ # loss/epoch if it improved.
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+def run(rank, world_size, args):
+    """Set up and run training in one process (one DDP rank).
+
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if not params.use_transducer:
+        params.ctc_loss_scale = 1.0
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    optimizer = ScaledAdam(
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr,  # should have no effect
+        clipping_scale=2.0,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            512
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    data_module = AsrDataModule(args)
+    multi_dataset = MultiDataset(args)
+
+    train_cuts = multi_dataset.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Drop utterances shorter than 1 second. Despite the function's
+        # name, no upper bound on duration is applied here.
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold. Cuts with fewer subsampled frames than tokens
+        # are also dropped, see below.
+        if c.duration < 1.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
+            return False
+
+        # In pruned RNN-T, we require that T >= S
+        # where T is the number of feature frames after subsampling
+        # and S is the number of tokens in the utterance
+
+        # In ./zipformer.py, the conv module uses the following expression
+        # for subsampling
+        T = ((c.num_frames - 7) // 2 + 1) // 2
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+        if T < len(tokens):
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before subsampling): {c.num_frames}. "
+                f"Number of frames (after subsampling): {T}. "
+                f"Text: {c.supervisions[0].text}. "
+                f"Tokens: {tokens}. "
+                f"Number of tokens: {len(tokens)}"
+            )
+            return False
+
+        return True
+
+    def tokenize_and_encode_text(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = byte_encode(tokenize_by_CJK_char(text))
+        c.supervisions[0].text = text
+        return c
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+    train_cuts = train_cuts.map(tokenize_and_encode_text)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = data_module.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = multi_dataset.dev_cuts()
+    valid_dl = data_module.valid_dataloaders(valid_cuts)
+
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        scheduler.step_epoch(epoch - 1)
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sp=sp,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        if params.print_diagnostics:
+            diagnostic.print_diagnostics()
+            break
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=train_dl.sampler,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = sp.encode(supervisions["text"], out_type=int)
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ sp: spm.SentencePieceProcessor,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+    """Parse the command line, then train in-process or spawn one process per GPU."""
+    arg_parser = get_parser()
+    AsrDataModule.add_arguments(arg_parser)
+    args = arg_parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    num_procs = args.world_size
+    assert num_procs >= 1
+    if num_procs > 1:
+        mp.spawn(run, args=(num_procs, args), nprocs=num_procs, join=True)
+        return
+    run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/multi_zh_en/ASR/zipformer/zipformer.py b/egs/multi_zh_en/ASR/zipformer/zipformer.py
new file mode 120000
index 000000000..23011dda7
--- /dev/null
+++ b/egs/multi_zh_en/ASR/zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/zipformer.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/README.md b/egs/voxpopuli/ASR/README.md
new file mode 100644
index 000000000..92aa26464
--- /dev/null
+++ b/egs/voxpopuli/ASR/README.md
@@ -0,0 +1,38 @@
+# Readme
+
+This recipe contains data preparation for the
+[VoxPopuli](https://github.com/facebookresearch/voxpopuli) dataset
+[(pdf)](https://aclanthology.org/2021.acl-long.80.pdf).
+At the moment, without model training.
+
+
+## audio per language
+
+| language | Size | Hrs. untranscribed | Hrs. transcribed |
+|----------|--------|--------------------|------------------|
+| bg | 295G | 17.6K | - |
+| cs | 308G | 18.7K | 62 |
+| da | 233G | 13.6K | - |
+| de | 379G | 23.2K | 282 |
+| el | 305G | 17.7K | - |
+| en | 382G | 24.1K | 543 |
+| es | 362G | 21.4K | 166 |
+| et | 179G | 10.6K | 3 |
+| fi | 236G | 14.2K | 27 |
+| fr | 376G | 22.8K | 211 |
+| hr | 132G | 8.1K | 43 |
+| hu | 297G | 17.7K | 63 |
+| it | 361G | 21.9K | 91 |
+| lt | 243G | 14.4K | 2 |
+| lv | 217G | 13.1K | - |
+| mt | 147G | 9.1K | - |
+| nl | 322G | 19.0K | 53 |
+| pl | 348G | 21.2K | 111 |
+| pt | 300G | 17.5K | - |
+| ro | 296G | 17.9K | 89 |
+| sk | 201G | 12.1K | 35 |
+| sl | 190G | 11.3K | 10 |
+| sv | 272G | 16.3K | - |
+| | | | |
+| total | 6.3T | 384K | 1791 |
+
diff --git a/egs/voxpopuli/ASR/local/compute_fbank.py b/egs/voxpopuli/ASR/local/compute_fbank.py
new file mode 100755
index 000000000..b63e51f29
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/compute_fbank.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of VoxPopuli dataset.
+
+Usage example:
+
+ python3 ./local/compute_fbank.py \
+ --src-dir data/fbank --output-dir data/fbank \
+ --num-jobs 100 --num-workers 25 \
+ --prefix "voxpopuli-${task}-${lang}" \
+ --dataset train \
+ --trim-to-supervisions True \
+ --speed-perturb True
+
+It looks for raw CutSet in the directory data/fbank
+located at: `{src_dir}/{prefix}_cuts_{dataset}_raw.jsonl.gz`.
+
+The generated fbank features are saved in `data/fbank/{prefix}-{dataset}_feats`
+and CutSet manifest stored in `data/fbank/{prefix}_cuts_{dataset}.jsonl.gz`.
+
+Typically, the number of workers is smaller than number of jobs
+(see --num-jobs 100 --num-workers 25 in the example).
+And, the number of jobs should be at least the number of workers (it's checked).
+"""
+
+import argparse
+import logging
+import multiprocessing
+import os
+from concurrent.futures import ProcessPoolExecutor
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from filter_cuts import filter_cuts
+from lhotse import (
+ CutSet,
+ Fbank,
+ FbankConfig,
+ LilcomChunkyWriter,
+ is_caching_enabled,
+ set_caching_enabled,
+)
+
+from icefall.utils import str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ help="""Path to the bpe.model. If not None, we will remove short and
+ long utterances before extracting features""",
+ )
+ parser.add_argument(
+ "--src-dir",
+ type=str,
+ help="""Folder with the input manifest files.""",
+ default="data/manifests",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ help="""Folder with the output manifests (cuts) and feature files.""",
+ default="data/fbank",
+ )
+
+ parser.add_argument(
+ "--prefix",
+ type=str,
+ help="""Prefix of the manifest files.""",
+ default="",
+ )
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="""Dataset parts to compute fbank (train,test,dev).""",
+ default=None,
+ )
+
+ parser.add_argument(
+ "--num-jobs",
+ type=int,
+ help="""Number of jobs (i.e. files with extracted features)""",
+ default=50,
+ )
+ parser.add_argument(
+ "--num-workers",
+ type=int,
+ help="""Number of parallel workers""",
+ default=10,
+ )
+ parser.add_argument(
+ "--speed-perturb",
+ type=str2bool,
+ default=False,
+ help="""Enable speed perturbation for the set.""",
+ )
+ parser.add_argument(
+ "--trim-to-supervisions",
+ type=str2bool,
+ default=False,
+ help="""Apply `trim-to-supervision` to cut set.""",
+ )
+
+ return parser.parse_args()
+
+
+def compute_fbank_features(args: argparse.Namespace):
+    """Compute fbank features for one raw CutSet and store the resulting CutSet."""
+    set_caching_enabled(True)  # lhotse
+    src_dir = Path(args.src_dir)
+    output_dir = Path(args.output_dir)
+    num_jobs = args.num_jobs
+    num_workers = min(args.num_workers, os.cpu_count() or 1)  # cpu_count() may be None
+    num_mel_bins = 80
+
+    bpe_model = args.bpe_model
+    if bpe_model:
+        logging.info(f"Loading {bpe_model}")
+        sp = spm.SentencePieceProcessor()
+        sp.load(bpe_model)
+
+    prefix = args.prefix  # e.g. "voxpopuli-asr-en"
+    dataset = args.dataset
+    suffix = "jsonl.gz"
+
+    cuts_raw_filename = Path(f"{src_dir}/{prefix}_cuts_{dataset}_raw.{suffix}")
+    cuts_raw = CutSet.from_file(cuts_raw_filename)
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    cuts_filename = Path(f"{prefix}_cuts_{dataset}.{suffix}")
+    if (output_dir / cuts_filename).is_file():
+        logging.info(f"{output_dir/cuts_filename} already exists - skipping.")
+        return
+
+    logging.info(f"Processing {output_dir/cuts_filename}")
+    cut_set = cuts_raw
+
+    if bpe_model:
+        cut_set = filter_cuts(cut_set, sp)
+
+    if args.speed_perturb:
+        cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+
+    if args.trim_to_supervisions:
+        logging.info(f"About to `trim_to_supervisions()` {output_dir / cuts_filename}")
+        cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
+    else:
+        logging.info(
+            "Not doing `trim_to_supervisions()`, "
+            "to enable use --trim-to-supervisions=True"
+        )
+
+    cut_set = cut_set.to_eager()  # disallow lazy evaluation (sorting requires it)
+    cut_set = cut_set.sort_by_recording_id()  # enhances AudioCache hit rate
+
+    # We typically use `num_jobs=100, num_workers=20`
+    # - this is helpful for large databases
+    # - both values are configurable externally
+    assert num_jobs >= num_workers, (num_jobs, num_workers)
+    executor = ProcessPoolExecutor(
+        max_workers=num_workers,
+        mp_context=multiprocessing.get_context("spawn"),
+        initializer=set_caching_enabled,
+        initargs=(is_caching_enabled(),),
+    )
+
+    logging.info(
+        f"executor {executor} : num_workers {num_workers}, num_jobs {num_jobs}"
+    )
+
+    cut_set = cut_set.compute_and_store_features(
+        extractor=extractor,
+        storage_path=f"{output_dir / prefix}-{dataset}_feats",
+        num_jobs=num_jobs,
+        executor=executor,
+        storage_type=LilcomChunkyWriter,
+    )
+
+    # correct small deviations of duration, caused by speed-perturbation
+    for cut in cut_set:
+        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut.id)
+        duration_difference = abs(cut.supervisions[0].duration - cut.duration)
+        tolerance = 0.02  # 20ms
+        if duration_difference == 0.0:
+            pass
+        elif duration_difference <= tolerance:
+            logging.info(
+                "small mismatch of the supervision duration "
+                f"(Δt = {duration_difference*1000}ms), "
+                f"correcting : cut.duration {cut.duration} -> "
+                f"supervision {cut.supervisions[0].duration}"
+            )
+            cut.supervisions[0].duration = cut.duration
+        else:
+            logging.error(
+                "mismatch of cut/supervision duration "
+                f"(Δt = {duration_difference*1000}ms) : "
+                f"cut.duration {cut.duration}, "
+                f"supervision {cut.supervisions[0].duration}"
+            )
+            raise ValueError(
+                "mismatch of cut/supervision duration "
+                f"(Δt = {duration_difference*1000}ms)"
+            )
+
+    # store the cutset
+    logging.info(f"storing CutSet to : `{output_dir / cuts_filename}`")
+    cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ logging.info(vars(args))
+
+ compute_fbank_features(args)
diff --git a/egs/voxpopuli/ASR/local/compute_fbank_musan.py b/egs/voxpopuli/ASR/local/compute_fbank_musan.py
new file mode 120000
index 000000000..5833f2484
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/display_manifest_statistics.py b/egs/voxpopuli/ASR/local/display_manifest_statistics.py
new file mode 100755
index 000000000..36c99e126
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed value to choose minimum/maximum duration
+to remove short and long utterances during the training.
+
+Usage example:
+ python3 ./local/display_manifest_statistics.py data/fbank/*_cuts*.jsonl.gz
+
+See the function `remove_short_and_long_utt()` in transducer/train.py
+for usage.
+
+"""
+
+import argparse
+
+from lhotse import load_manifest_lazy
+
+
+def get_args():
+ parser = argparse.ArgumentParser("Compute statistics for 'cuts' .jsonl.gz")
+
+ parser.add_argument(
+ "filename",
+ help="data/fbank/imported_cuts_bison-train_trim.jsonl.gz",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ cuts = load_manifest_lazy(args.filename)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py
new file mode 100755
index 000000000..957267fe8
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script computes durations of datasets from
+the SupervisionSet manifests.
+
+Usage example:
+
+ python3 ./local/duration_from_supervision_manifest.py \
+        data/manifests/*_supervisions*.jsonl.gz
+"""
+
+import argparse
+import gzip
+import json
+import logging
+import re
+import sys
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ "Read the raw text from the 'supervisions.jsonl.gz'"
+ )
+
+ parser.add_argument(
+ "filename",
+ help="supervisions.jsonl.gz",
+ nargs="+",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+ logging.info(vars(args))
+
+ total_duration = 0.0
+ total_n_utts = 0
+
+ for fname in args.filename:
+ if fname == "-":
+ fd = sys.stdin
+ elif re.match(r".*\.jsonl\.gz$", fname):
+ fd = gzip.open(fname, mode="r")
+ else:
+ fd = open(fname, mode="r")
+
+ fname_duration = 0.0
+ n_utts = 0
+ for line in fd:
+ js = json.loads(line)
+ fname_duration += js["duration"]
+ n_utts += 1
+
+ print(
+ f"Duration: {fname_duration/3600:7.2f} hours "
+ f"(eq. {fname_duration:7.0f} seconds, {n_utts} utts): {fname}"
+ )
+
+ if fd != sys.stdin:
+ fd.close()
+
+ total_duration += fname_duration
+ total_n_utts += n_utts
+
+ print(
+ f"Total duration: {total_duration/3600:7.2f} hours "
+ f"(eq. {total_duration:7.0f} seconds)"
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/voxpopuli/ASR/local/filter_cuts.py b/egs/voxpopuli/ASR/local/filter_cuts.py
new file mode 120000
index 000000000..27aca1729
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/filter_cuts.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/filter_cuts.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/prepare_lang_bpe.py b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 000000000..36b40e7fc
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py
new file mode 100755
index 000000000..4032537db
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
+# 2023 Brno University of Technology (author: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Preprocess the database.
+- Convert RecordingSet and SupervisionSet to CutSet.
+- Apply text normalization to the transcripts.
+ - We take renormalized `orig_text` as `text` transcripts.
+  - The text normalization separates punctuation from words.
+  - We also capitalize the first letter of each sentence.
+
+The script is inspired by:
+ `egs/commonvoice/ASR/local/preprocess_commonvoice.py`
+
+Usage example:
+ python3 ./local/preprocess_voxpopuli.py \
+ --task asr --lang en
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Optional
+
+from lhotse import CutSet
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# from local/
+from separate_punctuation import separate_punctuation
+from uppercase_begin_of_sentence import UpperCaseBeginOfSentence
+
+from icefall.utils import str2bool
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="""Dataset parts to compute fbank. If None, we will use all""",
+ default=None,
+ )
+
+ parser.add_argument(
+ "--task",
+ type=str,
+ help="""Task of VoxPopuli""",
+ default="asr",
+ )
+
+ parser.add_argument(
+ "--lang",
+ type=str,
+ help="""Language of VoxPopuli""",
+ required=True,
+ )
+
+ parser.add_argument(
+ "--use-original-text",
+ type=str2bool,
+ help="""Use 'original_text' from the annoattaion file,
+ otherwise 'normed_text' will be used
+ (see `data/manifests/${task}_${lang}.tsv.gz`).
+ """,
+ default=False,
+ )
+
+ return parser.parse_args()
+
+
+def normalize_text(utt: str) -> str:
+ utt = UpperCaseBeginOfSentence().process_line_text(separate_punctuation(utt))
+ return utt
+
+
+def preprocess_voxpopuli(
+    task: str,
+    language: str,
+    dataset: Optional[str] = None,
+    use_original_text: bool = False,
+):
+    """Build raw CutSets (resampled to 16000 Hz) for the requested partitions.
+
+    Reads `data/manifests`, writes `data/fbank/{prefix}_cuts_{part}_raw.jsonl.gz`.
+    """
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    if dataset is None:
+        dataset_parts = ("dev", "test", "train")
+    else:
+        dataset_parts = dataset.split(" ", -1)
+
+    logging.info("Loading manifest")
+    prefix = f"voxpopuli-{task}-{language}"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        suffix=suffix,
+        prefix=prefix,
+    )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
+    for partition, m in manifests.items():
+        logging.info(f"Processing {partition}")
+        raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
+        if raw_cuts_path.is_file():
+            logging.info(f"{partition} already exists - skipping")
+            continue
+
+        if use_original_text:
+            logging.info("Using 'original_text' from the annotation file.")
+            logging.info(f"Normalizing text in {partition}")
+            for sup in m["supervisions"]:
+                # `orig_text` includes punctuation and true-case
+                orig_text = str(sup.custom["orig_text"])
+                # we replace `text` by normalized `orig_text`
+                sup.text = normalize_text(orig_text)
+        else:
+            logging.info("Using 'normed_text' from the annotation file.")
+
+        # remove supervisions with empty or missing 'text'
+        m["supervisions"] = m["supervisions"].filter(lambda sup: bool(sup.text))
+
+        # Create cut manifest with long-recordings.
+        cut_set = CutSet.from_manifests(
+            recordings=m["recordings"],
+            supervisions=m["supervisions"],
+        ).resample(16000)
+
+        # Store the cut set incl. the resampling.
+        logging.info(f"Saving to {raw_cuts_path}")
+        cut_set.to_file(raw_cuts_path)
+
+
+def main():
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ args = get_args()
+ logging.info(vars(args))
+ preprocess_voxpopuli(
+ task=args.task,
+ language=args.lang,
+ dataset=args.dataset,
+ use_original_text=args.use_original_text,
+ )
+ logging.info("Done")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/voxpopuli/ASR/local/separate_punctuation.py b/egs/voxpopuli/ASR/local/separate_punctuation.py
new file mode 100755
index 000000000..706d6fcd5
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/separate_punctuation.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script chops the punctuation as standalone tokens.
+Example:
+ input: "This is fine. Yes, you are right."
+ output: "This is fine . Yes , you are right ."
+
+The script also handles exceptions in a hard-coded fashion.
+
+(same functionality could be done with `nltk.tokenize.word_tokenize()`,
+ but that would be an extra dependency)
+
+It can be used as a module, or as an executable script.
+
+Usage example #1:
+ `from separate_punctuation import separate_punctuation`
+
+Usage example #2:
+```
+ python3 ./local/separate_punctuation.py \
+ --ignore-columns 1 \
+ < ${kaldi_data}/text
+```
+"""
+
+import re
+import sys
+from argparse import ArgumentParser
+
+
+def separate_punctuation(text: str) -> str:
+    """
+    Text filtering function for separating punctuation.
+
+    Example:
+      input: "This is fine. Yes, you are right."
+      output: "This is fine . Yes , you are right ."
+
+    The exceptions for which the punctuation is
+    not split are hard-coded.
+    """
+
+    # remove non-desired punctuation symbols
+    text = re.sub('["„“«»]', "", text)
+
+    # separate [,.!?;] punctuation from words by space
+    text = re.sub(r"(\w)([,.!?;])", r"\1 \2", text)
+    text = re.sub(r"([,.!?;])(\w)", r"\1 \2", text)
+
+    # split to tokens
+    tokens = text.split()
+    tokens_out = []
+
+    # re-join the special cases of punctuation
+    for ii, tok in enumerate(tokens):
+        # no rewriting for 1st and last token
+        if ii > 0 and ii < len(tokens) - 1:
+            # **RULES ADDED FOR CZECH COMMON VOICE**
+
+            # fix "27 . dubna" -> "27. dubna", but keep punctuation separate,
+            if tok == "." and tokens[ii - 1].isdigit() and tokens[ii + 1].islower():
+                tokens_out[-1] = tokens_out[-1] + "."
+                continue
+
+            # fix "resp . pak" -> "resp. pak"
+            if tok == "." and tokens[ii - 1].isalpha() and tokens[ii + 1].islower():
+                tokens_out[-1] = tokens_out[-1] + "."
+                continue
+
+            # **RULES ADDED FOR ENGLISH COMMON VOICE**
+
+            # fix "A ." -> "A." (a single capital letter, e.g. an initial)
+            if tok == "." and re.match(r"^[A-Z]$", tokens[ii - 1]):
+                tokens_out[-1] = tokens_out[-1] + "."
+                continue
+
+            # fix "Mr ." -> "Mr."
+            exceptions = set(["Mr", "Mrs", "Ms"])
+            if tok == "." and tokens[ii - 1] in exceptions:
+                tokens_out[-1] = tokens_out[-1] + "."
+                continue
+
+        tokens_out.append(tok)
+
+    return " ".join(tokens_out)
+
+
+def get_args():
+ parser = ArgumentParser(
+ description="Separate punctuation from words: 'hello.' -> 'hello .'"
+ )
+ parser.add_argument(
+ "--ignore-columns", type=int, default=1, help="skip number of initial columns"
+ )
+ return parser.parse_args()
+
+
+def main():
+    """Filter stdin to stdout, keeping the first --ignore-columns fields as-is."""
+    args = get_args()
+    max_split = args.ignore_columns
+
+    for line in sys.stdin:
+        fields = line.strip().split(maxsplit=max_split)
+        # Lines with no text column (blank, or key only) are echoed unchanged
+        # instead of crashing or having their key rewritten as text.
+        if len(fields) <= max_split:
+            print(" ".join(fields))
+            continue
+        *key, text = fields
+        print(" ".join(key), separate_punctuation(text))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/voxpopuli/ASR/local/text_from_manifest.py b/egs/voxpopuli/ASR/local/text_from_manifest.py
new file mode 100755
index 000000000..d9ab53b5a
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/text_from_manifest.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Print the text contained in `supervisions.jsonl.gz` or `cuts.jsonl.gz`.
+
+Usage example:
+ python3 ./local/text_from_manifest.py \
+ data/manifests/voxpopuli-asr-en_supervisions_dev.jsonl.gz
+"""
+
+import argparse
+import gzip
+import json
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ "Read the raw text from the 'supervisions.jsonl.gz'"
+ )
+ parser.add_argument("filename", help="supervisions.jsonl.gz")
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ with gzip.open(args.filename, mode="r") as fd:
+ for line in fd:
+ js = json.loads(line)
+ if "text" in js:
+ print(js["text"]) # supervisions.jsonl.gz
+ elif "supervisions" in js:
+ for s in js["supervisions"]:
+ print(s["text"]) # cuts.jsonl.gz
+ else:
+ raise Exception(f"Unknown jsonl format of {args.filename}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/voxpopuli/ASR/local/train_bpe_model.py b/egs/voxpopuli/ASR/local/train_bpe_model.py
new file mode 120000
index 000000000..6fad36421
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py
new file mode 100755
index 000000000..8e9de905f
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script introduces initial capital letter at the beginning of a sentence.
+It can be used as a module, or as an executable script.
+
+Usage example #1:
+ `from uppercase_begin_of_sentence import UpperCaseBeginOfSentence`
+
+Usage example #2:
+```
+ python3 ./local/uppercase_begin_of_sentence.py \
+ --ignore-columns 1 \
+ < ${kaldi_data}/text
+```
+"""
+
+import re
+import sys
+from argparse import ArgumentParser
+
+
+class UpperCaseBeginOfSentence:
+    """
+    Capitalize the first word of every sentence in tokenized text.
+
+    A word starts a sentence if the previous token was a punctuation token
+    from `set([".", "!", "?"])`, or if it is the very first word processed.
+    This state is kept in `self.prev_token_is_punct`, so it is carried over
+    across `process_line_text()` calls.
+    """
+
+    def __init__(self):
+        # True at the start, so the 1st word of the 1st line is capitalized.
+        # Later on it transfers the context from the previous line.
+        self.prev_token_is_punct = True
+
+    def process_line_text(self, line_text: str) -> str:
+        """
+        Capitalize the sentence-initial words of one line.
+
+        It is assumed that punctuation in `line_text` was already separated,
+        example: "This is fine . Yes , you are right ."
+
+        Args:
+          line_text: Text to process; it is split on whitespace.
+        Returns:
+          The processed tokens joined by single spaces.
+        """
+        words = line_text.split()
+        punct_set = {".", "!", "?"}
+
+        for ii, w in enumerate(words):
+            if w in punct_set:
+                self.prev_token_is_punct = True
+            elif self.prev_token_is_punct and not re.match("<", w):
+                # Tokens starting with "<" are skipped above and keep the state.
+                # Only lowercase words are changed; `capitalize()` alters just
+                # the 1st character, whereas `title()` would also upper-case
+                # letters after "'" or "-" ("don't" -> "Don'T").
+                if w.islower():
+                    words[ii] = w.capitalize()
+                self.prev_token_is_punct = False
+
+        return " ".join(words)
+
+
+def get_args():
+    """Parse the command line; the only option is `--ignore-columns` (int)."""
+    description = "Put upper-case at the beginning of a sentence."
+    column_opts = dict(type=int, default=4, help="skip number of initial columns")
+
+    cli = ArgumentParser(description=description)
+    cli.add_argument("--ignore-columns", **column_opts)
+    return cli.parse_args()
+
+
+def main():
+    """
+    Copy stdin to stdout with sentence-initial words capitalized. The first
+    `--ignore-columns` fields of a line are keys, they are kept unchanged.
+    """
+    args = get_args()
+    uc_bos = UpperCaseBeginOfSentence()
+    max_split = args.ignore_columns
+
+    for line in sys.stdin:
+        line = line.strip()
+        *key, text = line.split(maxsplit=max_split) or [""]
+        # A lone field is a key without text, unless no column is ignored.
+        if key or (text and max_split == 0):
+            # `print(*key, ...)` emits no leading space for an empty `key`.
+            print(*key, uc_bos.process_line_text(text))
+        else:
+            print(line)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 000000000..721bb48e7
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/validate_cutset_manifest.py b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py
new file mode 100755
index 000000000..4659aa9cd
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+- Supervision time bounds are within Cut time bounds
+- Duration of Cut and Supervision are equal
+
+We will add more checks later if needed.
+
+Usage example:
+
+ python3 ./local/validate_cutset_manifest.py \
+ ./data/fbank/voxpopuli-asr-en_cuts_dev.jsonl.gz
+
+(Based on: `librispeech/ASR/local/validate_manifest.py`)
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.cut import Cut
+from lhotse.dataset.speech_recognition import validate_for_asr
+
+
+def get_args():
+    """
+    Parse the command line.
+
+    Returns:
+      A namespace whose `cutset_manifest` attribute is the manifest `Path`.
+    """
+    cli = argparse.ArgumentParser()
+    cli.add_argument("cutset_manifest", type=Path, help="Path to the manifest file")
+    return cli.parse_args()
+
+
+def validate_one_supervision_per_cut(c: Cut):
+    """Raise ValueError unless the cut `c` has exactly one supervision."""
+    num_supervisions = len(c.supervisions)
+    if num_supervisions != 1:
+        raise ValueError(f"{c.id} has {num_supervisions} supervisions")
+
+
+def validate_supervision_and_cut_time_bounds(c: Cut):
+    """Raise ValueError unless the 1st supervision of `c` spans the whole cut."""
+    tol = 2e-3  # same tolerance as in 'validate_for_asr()'
+    sup = c.supervisions[0]
+    # Supervision start time is relative to Cut ...
+    # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
+    if sup.start < -tol:
+        msg = f"{c.id}: Supervision start time {sup.start} must not be negative."
+        raise ValueError(msg)
+    if sup.start > tol:
+        raise ValueError(
+            f"{c.id}: Supervision start time {sup.start} is not at the "
+            "beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`."
+        )
+    sup_end = c.start + sup.end
+    if sup_end > c.end + tol:
+        raise ValueError(
+            f"{c.id}: Supervision end time {sup_end} is larger "
+            f"than cut end time {c.end}"
+        )
+    if sup.duration != c.duration:  # compared exactly, without `tol`
+        raise ValueError(
+            f"{c.id}: Cut duration {c.duration} and supervision duration "
+            f"{sup.duration} must be the same.\n"
+            "The difference causes problems in the training code : "
+            "+/- 1 frame in `x`, `x_lens` in `Zipformer::forward()`.\n"
+            "Did you forget to apply `trim_to_supervisions()` ?"
+        )
+
+
+def main():
+    """Validate the CutSet manifest given on the command line."""
+    manifest = get_args().cutset_manifest
+    logging.info(f"Validating {manifest}")
+
+    assert manifest.is_file(), f"{manifest} does not exist"
+    cut_set = load_manifest_lazy(manifest)
+    assert isinstance(cut_set, CutSet)
+
+    per_cut_checks = (
+        validate_one_supervision_per_cut,
+        validate_supervision_and_cut_time_bounds,
+    )
+    try:
+        for cut in cut_set:
+            for check in per_cut_checks:
+                check(cut)
+        # Validation from K2 training
+        # - checks supervision start is 0
+        # - checks supervision.duration is not longer than cut.duration
+        # - there is tolerance 2ms
+        validate_for_asr(cut_set)
+    except BaseException as err:  # the error is logged, then re-raised
+        logging.error(str(err))
+        raise
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/voxpopuli/ASR/prepare.sh b/egs/voxpopuli/ASR/prepare.sh
new file mode 100755
index 000000000..7cddad756
--- /dev/null
+++ b/egs/voxpopuli/ASR/prepare.sh
@@ -0,0 +1,257 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -euxo pipefail  # stop on errors, unset variables and failed pipes; trace commands
+
+nj=20  # passed as `-j` to `lhotse prepare` and as `--num-workers` to compute_fbank.py
+stage=-1  # stages numbered below `stage` are skipped
+stop_stage=100  # stages numbered above `stop_stage` are skipped
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/voxpopuli/raw_audios/$lang/$year
+# This directory contains *.ogg files with audio downloaded and extracted from archives:
+# https://dl.fbaipublicfiles.com/voxpopuli/audios/${lang}_${year}.tar
+#
+# - Note: the voxpopuli transcripts are downloaded to a ${tmp} folder
+# as part of `lhotse prepare voxpopuli` from:
+# https://dl.fbaipublicfiles.com/voxpopuli/annotations/asr/asr_${lang}.tsv.gz
+#
+# - $musan_dir/musan (with the default settings: $dl_dir/musan/musan)
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+
+dl_dir=$PWD/download  # VoxPopuli is expected in (or downloaded to) $dl_dir/voxpopuli
+#dl_dir=/mnt/matylda6/szoke/EU-ASR/DATA # BUT (site-specific alternative)
+
+musan_dir=${dl_dir}/musan  # the MUSAN corpus is expected in $musan_dir/musan
+#musan_dir=/mnt/matylda2/data/MUSAN # BUT (site-specific alternative)
+
+# Choose value from ASR_LANGUAGES:
+#
+# [ "en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr",
+# "sk", "sl", "et", "lt" ]
+#
+# See ASR_LANGUAGES in:
+# https://github.com/lhotse-speech/lhotse/blob/c5f26afd100885b86e4244eeb33ca1986f3fa923/lhotse/recipes/voxpopuli.py#L54C4-L54C4
+lang=en
+
+task=asr
+
+. shared/parse_options.sh || exit 1  # presumably lets command-line options override the variables above; confirm
+
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx_${lang},
+# data/lang_bpe_yyy_${lang} if the array contains xxx, yyy
+vocab_sizes=(
+ # 5000
+ # 2000
+ # 1000
+ 500
+)
+
+# NOTE(review): the stages below write to data/manifests, data/fbank and
+# data/lang_bpe_*_${lang}; none of them is seen to write into "data/${lang}".
+mkdir -p data/${lang}
+
+log() {
+  # From espnet: print a timestamped message with the caller's file:line:function.
+  local where="${BASH_SOURCE[1]##*/}:${BASH_LINENO[0]}:${FUNCNAME[1]}"
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${where}) $*"
+}
+
+log "dl_dir: $dl_dir"
+log "musan_dir: $musan_dir"
+log "task: $task, lang: $lang"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded VoxPopuli to /path/to/voxpopuli (it has to
+ # contain raw_audios/${lang}), you can create a symlink
+ #
+ # ln -sfv /path/to/voxpopuli $dl_dir/voxpopuli
+ #
+ if [ ! -d $dl_dir/voxpopuli/raw_audios/${lang} ]; then
+ lhotse download voxpopuli --subset $lang $dl_dir/voxpopuli
+ fi
+
+ # If you have pre-downloaded musan to /path/to/musan,
+ # you can create a symlink (the check below looks for $musan_dir/musan)
+ #
+ # mkdir -p $musan_dir && ln -sfv /path/to/musan $musan_dir/musan
+ #
+ if [ ! -d $musan_dir/musan ]; then
+ lhotse download musan $musan_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare VoxPopuli manifest"
+  # We assume that you have downloaded the VoxPopuli corpus to $dl_dir/voxpopuli
+  if [ ! -e data/manifests/.voxpopuli-${task}-${lang}.done ]; then
+    # Warning : it requires Internet connection (it downloads transcripts to ${tmpdir})
+    # Pass $task, so the manifests match the task named in the ".done" marker.
+    lhotse prepare voxpopuli --task $task --lang $lang -j $nj $dl_dir/voxpopuli data/manifests
+    touch data/manifests/.voxpopuli-${task}-${lang}.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare musan manifest"
+ # We assume that you have downloaded the musan corpus
+ # to $musan_dir/musan
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.musan.done ]; then
+ # the ".done" marker touched below lets a re-run skip this step
+ lhotse prepare musan $musan_dir/musan data/manifests
+ touch data/manifests/.musan.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Preprocess VoxPopuli manifest"
+  mkdir -p data/fbank
+  # recordings + supervisions -> cutset; skipped if the marker file exists
+  preprocess_done=data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete
+  if [ ! -e $preprocess_done ]; then
+    ./local/preprocess_voxpopuli.py --task $task --lang $lang --use-original-text True
+    touch $preprocess_done
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for dev and test subsets of VoxPopuli"
+  mkdir -p data/fbank
+  fbank_prefix="voxpopuli-${task}-${lang}"
+  for dataset in dev test; do
+    # a subset that already has its ".done" marker is skipped
+    [ -e data/fbank/.${fbank_prefix}-${dataset}.done ] && continue
+    ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
+      --num-jobs 50 --num-workers ${nj} \
+      --prefix "$fbank_prefix" --dataset ${dataset} \
+      --trim-to-supervisions True
+    touch data/fbank/.${fbank_prefix}-${dataset}.done
+  done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for train set of VoxPopuli"
+  # as for dev/test, plus `--speed-perturb True`; skipped if the marker exists
+  train_done=data/fbank/.voxpopuli-${task}-${lang}-train.done
+  if [ ! -e $train_done ]; then
+    ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
+      --num-jobs 100 --num-workers ${nj} \
+      --prefix "voxpopuli-${task}-${lang}" --dataset train \
+      --trim-to-supervisions True --speed-perturb True
+    touch $train_done
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Validate fbank manifests for VoxPopuli"
+  mkdir -p data/fbank/log/
+  for dataset in "dev" "test" "train"; do
+    cuts_name=voxpopuli-${task}-${lang}_cuts_${dataset}  # not a fixed "asr-en"
+    ./local/validate_cutset_manifest.py data/fbank/${cuts_name}.jsonl.gz \
+      2>&1 | tee data/fbank/log/validate_${cuts_name}.log
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Compute fbank for musan"
+ mkdir -p data/fbank
+ if [ ! -e data/fbank/.musan.done ]; then # skip if the marker of a finished run exists
+ ./local/compute_fbank_musan.py
+ touch data/fbank/.musan.done # reached only if the script above succeeded (`set -e`)
+ fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare BPE based lang"
+
+  for vocab_size in "${vocab_sizes[@]}"; do
+    lang_dir=data/lang_bpe_${vocab_size}_${lang}
+    mkdir -p $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      log "Generate data for BPE training"
+      # `find` fails (and `set -e` stops the script) if the train cuts are missing.
+      file=$(
+        find "data/fbank/voxpopuli-${task}-${lang}_cuts_train.jsonl.gz"
+      )
+      local/text_from_manifest.py $file >$lang_dir/transcript_words.txt
+    fi
+
+    # words.txt maps words to integer ids: "<eps>" gets 0, the sorted words
+    # get 1..N, and "#0", "<s>", "</s>" get N+1..N+3, hence "<s>" / "</s>"
+    # must not be words. NOTE(review): '<SPOKEN_NOISE>' and '<UNK>' follow
+    # the usual Kaldi/k2 lang-dir convention; confirm against prepare_lang_bpe.py.
+    if [ ! -f $lang_dir/words.txt ]; then
+      cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
+        | sort -u | sed '/^$/d' > $lang_dir/words.txt
+      (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
+        cat - $lang_dir/words.txt | sort | uniq | awk '
+        BEGIN {
+          print "<eps> 0";
+        }
+        {
+          if ($1 == "<s>") {
+            print "<s> is in the vocabulary!" | "cat 1>&2"
+            exit 1;
+          }
+          if ($1 == "</s>") {
+            print "</s> is in the vocabulary!" | "cat 1>&2"
+            exit 1;
+          }
+          printf("%s %d\n", $1, NR);
+        }
+        END {
+          printf("#0 %d\n", NR+1);
+          printf("<s> %d\n", NR+2);
+          printf("</s> %d\n", NR+3);
+        }' > $lang_dir/words || exit 1;
+      mv $lang_dir/words $lang_dir/words.txt
+    fi
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --bpe-model $lang_dir/bpe.model
+    fi
+
+    if [ ! -f $lang_dir/L.fst ]; then
+      log "Converting L.pt to L.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L.pt \
+        $lang_dir/L.fst
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.fst ]; then
+      log "Converting L_disambig.pt to L_disambig.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L_disambig.pt \
+        $lang_dir/L_disambig.fst
+    fi
+  done
+fi
diff --git a/egs/voxpopuli/ASR/shared b/egs/voxpopuli/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/voxpopuli/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 36b8a4b67..d665f3364 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -868,7 +868,7 @@ def main():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
- context_graph.build(contexts)
+ context_graph.build([(c, 0.0) for c in contexts])
else:
context_graph = None
else:
diff --git a/icefall/context_graph.py b/icefall/context_graph.py
index 0b7c42c0b..b3d7972a8 100644
--- a/icefall/context_graph.py
+++ b/icefall/context_graph.py
@@ -84,6 +84,9 @@ class ContextGraph:
context_score:
The bonus score for each token(note: NOT for each word/phrase, it means longer
word/phrase will have larger bonus score, they have to be matched though).
+ Note: This is just the default score for each token, the users can manually
+ specify the context_score for each word/phrase (i.e. different phrase might
+ have different token score).
"""
self.context_score = context_score
self.num_nodes = 0
@@ -133,7 +136,7 @@ class ContextGraph:
node.output_score += 0 if output is None else output.output_score
queue.append(node)
- def build(self, token_ids: List[List[int]]):
+ def build(self, token_ids: List[Tuple[List[int], float]]):
"""Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc
for each trie node.
@@ -142,26 +145,46 @@ class ContextGraph:
Args:
token_ids:
- The given token lists to build the ContextGraph, it is a list of token list,
- each token list contains the token ids for a word/phrase. The token id
- could be an id of a char (modeling with single Chinese char) or an id
- of a BPE (modeling with BPEs).
+ The given token lists to build the ContextGraph, it is a list of tuple of
+ token list and its customized score, the token list contains the token ids
+ for a word/phrase. The token id could be an id of a char
+ (modeling with single Chinese char) or an id of a BPE
+ (modeling with BPEs). The score is the total score for current token list,
+ 0 means using the default value (i.e. self.context_score).
+
+ Note: The phrases would have shared states, the score of the shared states is
+ the maximum value among all the tokens sharing this state.
"""
- for tokens in token_ids:
+ for (tokens, score) in token_ids:
node = self.root
+ # If a customized score is given, use it (split evenly over the tokens,
+ # rounded to 2 decimals); otherwise use the default per-token score.
+ context_score = (
+ self.context_score if score == 0.0 else round(score / len(tokens), 2)
+ )
for i, token in enumerate(tokens):
+ node_next = {}
if token not in node.next:
self.num_nodes += 1
+ node_id = self.num_nodes
+ token_score = context_score
is_end = i == len(tokens) - 1
- node_score = node.node_score + self.context_score
- node.next[token] = ContextState(
- id=self.num_nodes,
- token=token,
- token_score=self.context_score,
- node_score=node_score,
- output_score=node_score if is_end else 0,
- is_end=is_end,
- )
+ else:
+ # node exists, get the score of shared state.
+ token_score = max(context_score, node.next[token].token_score)
+ node_id = node.next[token].id
+ node_next = node.next[token].next
+ is_end = i == len(tokens) - 1 or node.next[token].is_end
+ node_score = node.node_score + token_score
+ node.next[token] = ContextState(
+ id=node_id,
+ token=token,
+ token_score=token_score,
+ node_score=node_score,
+ output_score=node_score if is_end else 0,
+ is_end=is_end,
+ )
+ node.next[token].next = node_next
node = node.next[token]
self._fill_fail_output()
@@ -343,7 +366,7 @@ class ContextGraph:
return dot
-if __name__ == "__main__":
+def _test(queries, score):
contexts_str = [
"S",
"HE",
@@ -355,9 +378,11 @@ if __name__ == "__main__":
"THIS",
"THEM",
]
+
+ # test default score (1)
contexts = []
for s in contexts_str:
- contexts.append([ord(x) for x in s])
+ contexts.append(([ord(x) for x in s], score))
context_graph = ContextGraph(context_score=1)
context_graph.build(contexts)
@@ -369,10 +394,28 @@ if __name__ == "__main__":
context_graph.draw(
title="Graph for: " + " / ".join(contexts_str),
- filename="context_graph.pdf",
+ filename=f"context_graph_{score}.pdf",
symbol_table=symbol_table,
)
+ for query, expected_score in queries.items():
+ total_scores = 0
+ state = context_graph.root
+ for q in query:
+ score, state = context_graph.forward_one_step(state, ord(q))
+ total_scores += score
+ score, state = context_graph.finalize(state)
+ assert state.token == -1, state.token
+ total_scores += score
+ assert round(total_scores, 2) == expected_score, (
+ total_scores,
+ expected_score,
+ query,
+ )
+
+
+if __name__ == "__main__":
+ # test default score
queries = {
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
@@ -384,17 +427,27 @@ if __name__ == "__main__":
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
- for query, expected_score in queries.items():
- total_scores = 0
- state = context_graph.root
- for q in query:
- score, state = context_graph.forward_one_step(state, ord(q))
- total_scores += score
- score, state = context_graph.finalize(state)
- assert state.token == -1, state.token
- total_scores += score
- assert total_scores == expected_score, (
- total_scores,
- expected_score,
- query,
- )
+ _test(queries, 0)
+
+ # test custom score (5)
+ # S : 5
+ # HE : 5 (2.5 + 2.5)
+ # SHE : 8.34 (5 + 1.67 + 1.67)
+ # SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1)
+ # HIS : 5.84 (2.5 + 1.67 + 1.67)
+ # HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25)
+ # HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1)
+ # THIS : 5 (1.25 + 1.25 + 1.25 + 1.25)
+ queries = {
+ "HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE"
+ "HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE"
+ "HISHE": 24.18, # "HIS", "S", "SHE", "HE"
+ "SHED": 18.34, # "S", "SHE", "HE"
+ "SHELF": 18.34, # "S", "SHE", "HE"
+ "HELL": 5, # "HE"
+ "HELLO": 13, # "HE", "HELLO"
+ "DHRHISQ": 10.84, # "HIS", "S"
+ "THEN": 5, # "HE"
+ }
+
+ _test(queries, 5)
diff --git a/pyproject.toml b/pyproject.toml
index c40143fb9..435256416 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,4 +14,5 @@ exclude = '''
| icefall\/diagnostics\.py
| icefall\/profiler\.py
| egs\/librispeech\/ASR\/zipformer
+ | egs\/ljspeech\/TTS\/vits
'''
diff --git a/requirements-ci.txt b/requirements-ci.txt
index e1232a768..6c74f688c 100644
--- a/requirements-ci.txt
+++ b/requirements-ci.txt
@@ -17,6 +17,7 @@ six
git+https://github.com/lhotse-speech/lhotse
kaldilm==1.11
kaldialign==0.7.1
+num2words
sentencepiece==0.1.96
tensorboard==2.8.0
typeguard==2.13.3
diff --git a/requirements.txt b/requirements.txt
index 5a8326619..9502fcbd2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
kaldifst
kaldilm
kaldialign
+num2words
kaldi-decoder
sentencepiece>=0.1.96
tensorboard