diff --git a/egs/zipvoice/README.md b/egs/zipvoice/README.md
new file mode 100644
index 000000000..0eed8f540
--- /dev/null
+++ b/egs/zipvoice/README.md
@@ -0,0 +1,360 @@
+## ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching
+
+
+[Paper](http://arxiv.org/abs/2506.13053)
+[Demo page](https://zipvoice.github.io/)
+
+
+## Overview
+ZipVoice is a high-quality zero-shot TTS model with a small model size and fast inference speed.
+#### Key features:
+
+- Small and fast: only 123M parameters.
+
+- High-quality: state-of-the-art voice cloning performance in speaker similarity, intelligibility, and naturalness.
+
+- Multi-lingual: supports Chinese and English.
+
+
+## News
+**2025/06/16**: 🔥 ZipVoice is released.
+
+
+## Installation
+```
+pip install -r requirements.txt
+```
+
+## Usage
+
+To generate speech with our pre-trained ZipVoice or ZipVoice-Distill models, use the following commands (the required models will be downloaded from HuggingFace):
+
+### 1. Inference of a single sentence:
+```bash
+python3 zipvoice/zipvoice_infer.py \
+ --model-name "zipvoice_distill" \
+ --prompt-wav prompt.wav \
+ --prompt-text "I am the transcription of the prompt wav." \
+ --text "I am the text to be synthesized." \
+ --res-wav-path result.wav
+```
+
+### 2. Inference of a list of sentences:
+```bash
+python3 zipvoice/zipvoice_infer.py \
+ --model-name "zipvoice_distill" \
+ --test-list test.tsv \
+ --res-dir results/test
+```
+- `--model-name` can be `zipvoice` or `zipvoice_distill`, which are models before and after distillation, respectively.
+- Each line of `test.tsv` is in the format of `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
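+
+For example, each line of `test.tsv` looks like the following (fields are separated by tabs; the wav name and paths here are only illustrative):
+
+```
+utt_0001    I am the transcription of the prompt wav.    prompt.wav    I am the text to be synthesized.
+```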
+
+
+> **Note:** If you have trouble connecting to HuggingFace, try:
+```bash
+export HF_ENDPOINT=https://hf-mirror.com
+```
+
+## Training Your Own Model
+
+The following steps show how to train a model from scratch on the Emilia and LibriTTS datasets, respectively.
+
+### 1. Data Preparation
+
+#### 1.1 Prepare the Emilia dataset
+
+#### 1.2 Prepare the LibriTTS dataset
+
+See [local/prepare_libritts.sh](local/prepare_libritts.sh)
+
+### 2. Training
+
+#### 2.1 Training on Emilia
+
+
+
+##### 2.1.1 Train the ZipVoice model
+
+- Training:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/train_flow.py \
+ --world-size 8 \
+ --use-fp16 1 \
+ --dataset emilia \
+ --max-duration 500 \
+ --lr-hours 30000 \
+ --lr-batches 7500 \
+ --token-file "data/tokens_emilia.txt" \
+ --manifest-dir "data/fbank_emilia" \
+ --num-epochs 11 \
+ --exp-dir zipvoice/exp_zipvoice
+```
+
+- Average the checkpoints to produce the final model:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/generate_averaged_model.py \
+ --epoch 11 \
+ --avg 4 \
+ --distill 0 \
+ --token-file data/tokens_emilia.txt \
+ --dataset "emilia" \
+ --exp-dir ./zipvoice/exp_zipvoice
+# The generated model is zipvoice/exp_zipvoice/epoch-11-avg-4.pt
+```
+
+##### 2.1.2. Train the ZipVoice-Distill model (Optional)
+
+- The first-stage distillation:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/train_distill.py \
+ --world-size 8 \
+ --use-fp16 1 \
+ --tensorboard 1 \
+ --dataset "emilia" \
+ --base-lr 0.0005 \
+ --max-duration 500 \
+ --token-file "data/tokens_emilia.txt" \
+ --manifest-dir "data/fbank_emilia" \
+ --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
+ --num-updates 60000 \
+ --distill-stage "first" \
+ --exp-dir zipvoice/exp_zipvoice_distill_1stage
+```
+
+- Average checkpoints for the second-stage initialization:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/generate_averaged_model.py \
+ --iter 60000 \
+ --avg 7 \
+ --distill 1 \
+ --token-file data/tokens_emilia.txt \
+ --dataset "emilia" \
+ --exp-dir ./zipvoice/exp_zipvoice_distill_1stage
+# The generated model is zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt
+```
+
+- The second-stage distillation:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/train_distill.py \
+ --world-size 8 \
+ --use-fp16 1 \
+ --tensorboard 1 \
+ --dataset "emilia" \
+ --base-lr 0.0001 \
+ --max-duration 200 \
+ --token-file "data/tokens_emilia.txt" \
+ --manifest-dir "data/fbank_emilia" \
+ --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \
+ --num-updates 2000 \
+ --distill-stage "second" \
+  --exp-dir zipvoice/exp_zipvoice_distill
+```
+
+
+
+#### 2.2 Training on LibriTTS
+
+
+
+##### 2.2.1 Train the ZipVoice model
+
+- Training:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/train_flow.py \
+ --world-size 8 \
+ --use-fp16 1 \
+ --dataset libritts \
+ --max-duration 250 \
+ --lr-epochs 10 \
+ --lr-batches 7500 \
+ --token-file "data/tokens_libritts.txt" \
+ --manifest-dir "data/fbank_libritts" \
+ --num-epochs 60 \
+ --exp-dir zipvoice/exp_zipvoice_libritts
+```
+
+- Average the checkpoints to produce the final model:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/generate_averaged_model.py \
+ --epoch 60 \
+ --avg 10 \
+ --distill 0 \
+ --token-file data/tokens_libritts.txt \
+ --dataset "libritts" \
+ --exp-dir ./zipvoice/exp_zipvoice_libritts
+# The generated model is zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt
+```
+
+##### 2.2.2 Train the ZipVoice-Distill model (Optional)
+
+- The first-stage distillation:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/train_distill.py \
+ --world-size 8 \
+ --use-fp16 1 \
+ --tensorboard 1 \
+ --dataset "libritts" \
+ --base-lr 0.001 \
+ --max-duration 250 \
+ --token-file "data/tokens_libritts.txt" \
+ --manifest-dir "data/fbank_libritts" \
+ --teacher-model zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
+ --num-epochs 6 \
+ --distill-stage "first" \
+ --exp-dir zipvoice/exp_zipvoice_distill_1stage_libritts
+```
+
+- Average checkpoints for the second-stage initialization:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 ./zipvoice/generate_averaged_model.py \
+ --epoch 6 \
+ --avg 3 \
+ --distill 1 \
+ --token-file data/tokens_libritts.txt \
+ --dataset "libritts" \
+ --exp-dir ./zipvoice/exp_zipvoice_distill_1stage_libritts
+# The generated model is zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt
+```
+
+- The second-stage distillation:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/train_distill.py \
+ --world-size 8 \
+ --use-fp16 1 \
+ --tensorboard 1 \
+ --dataset "libritts" \
+ --base-lr 0.001 \
+ --max-duration 250 \
+ --token-file "data/tokens_libritts.txt" \
+ --manifest-dir "data/fbank_libritts" \
+ --teacher-model zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \
+ --num-epochs 6 \
+ --distill-stage "second" \
+ --exp-dir zipvoice/exp_zipvoice_distill_libritts
+```
+
+- Average checkpoints to produce the final model:
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 ./zipvoice/generate_averaged_model.py \
+ --epoch 6 \
+ --avg 3 \
+ --distill 1 \
+ --token-file data/tokens_libritts.txt \
+ --dataset "libritts" \
+ --exp-dir ./zipvoice/exp_zipvoice_distill_libritts
+# The generated model is ./zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt
+```
+
+
+
+### 3. Inference with the trained model
+
+#### 3.1 Inference with the model trained on Emilia
+
+
+##### 3.1.1 ZipVoice model (before distillation)
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/infer.py \
+ --checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
+ --distill 0 \
+ --token-file "data/tokens_emilia.txt" \
+ --test-list test.tsv \
+ --res-dir results/test \
+ --num-step 16 \
+ --guidance-scale 1
+```
+
+##### 3.1.2 ZipVoice-Distill model
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/infer.py \
+ --checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
+ --distill 1 \
+ --token-file "data/tokens_emilia.txt" \
+ --test-list test.tsv \
+ --res-dir results/test_distill \
+ --num-step 8 \
+ --guidance-scale 3
+```
+
+
+
+#### 3.2 Inference with the model trained on LibriTTS
+
+
+
+##### 3.2.1 ZipVoice model (before distillation)
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/infer.py \
+ --checkpoint zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
+ --distill 0 \
+ --token-file "data/tokens_libritts.txt" \
+ --test-list test.tsv \
+ --res-dir results/test_libritts \
+ --num-step 8 \
+ --guidance-scale 1 \
+ --target-rms 1.0 \
+ --t-shift 0.7
+```
+
+##### 3.2.2 ZipVoice-Distill model
+
+```bash
+export PYTHONPATH=../../:$PYTHONPATH
+python3 zipvoice/infer.py \
+    --checkpoint zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt \
+ --distill 1 \
+ --token-file "data/tokens_libritts.txt" \
+ --test-list test.tsv \
+ --res-dir results/test_distill_libritts \
+ --num-step 4 \
+ --guidance-scale 3 \
+ --target-rms 1.0 \
+ --t-shift 0.7
+```
+
+
+### 4. Evaluation on benchmarks
+
+See [local/evaluate.sh](local/evaluate.sh) for details of objective metrics evaluation
+on three test sets, i.e., LibriSpeech-PC test-clean, Seed-TTS test-en and Seed-TTS test-zh.
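+
+For example, assuming the synthesized test sets and the evaluation models have been placed at the paths configured at the top of [local/evaluate.sh](local/evaluate.sh), the whole evaluation can be run with:
+
+```bash
+bash local/evaluate.sh
+```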
+
+
+## Citation
+
+```bibtex
+@article{zhu-2025-zipvoice,
+ title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching},
+  author={Han Zhu and Wei Kang and Zengwei Yao and Liyong Guo and Fangjun Kuang and Zhaoqing Li and Weiji Zhuang and Long Lin and Daniel Povey},
+ journal={arXiv preprint arXiv:2506.13053},
+ year={2025},
+}
+```
\ No newline at end of file
diff --git a/egs/zipvoice/local/compute_fbank_libritts.py b/egs/zipvoice/local/compute_fbank_libritts.py
new file mode 100755
index 000000000..0c9f464ea
--- /dev/null
+++ b/egs/zipvoice/local/compute_fbank_libritts.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao,)
+# 2024 The Chinese Univ. of HK (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the LibriTTS dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from feature import TorchAudioFbank, TorchAudioFbankConfig
+from lhotse import CutSet, LilcomChunkyWriter
+from lhotse.recipes.utils import read_manifests_if_cached
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ help="""Dataset parts to compute fbank. If None, we will use all""",
+ )
+ parser.add_argument(
+ "--sampling-rate",
+ type=int,
+ default=24000,
+ help="""Sampling rate of the waveform for computing fbank,
+ the default value for LibriTTS is 24000, waveform files will be
+ resampled if a different sample rate is provided""",
+ )
+
+ return parser.parse_args()
+
+
+def compute_fbank_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000):
+ src_dir = Path("data/manifests_libritts")
+ output_dir = Path("data/fbank_libritts")
+ num_jobs = min(32, os.cpu_count())
+
+ prefix = "libritts"
+ suffix = "jsonl.gz"
+ if dataset is None:
+ dataset_parts = (
+ "dev-clean",
+ "test-clean",
+ "train-clean-100",
+ "train-clean-360",
+ "train-other-500",
+ )
+ else:
+ dataset_parts = dataset.split(" ", -1)
+
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ config = TorchAudioFbankConfig(
+ sampling_rate=sampling_rate,
+ n_mels=100,
+ n_fft=1024,
+ hop_length=256,
+ )
+ extractor = TorchAudioFbank(config)
+
+ with get_executor() as ex: # Initialize the executor only once.
+ for partition, m in manifests.items():
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{partition} already exists - skipping.")
+                continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+ if sampling_rate != 24000:
+ logging.info(f"Resampling waveforms to {sampling_rate}")
+ cut_set = cut_set.resample(sampling_rate)
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
+    compute_fbank_libritts(dataset=args.dataset, sampling_rate=args.sampling_rate)
diff --git a/egs/zipvoice/local/evaluate.sh b/egs/zipvoice/local/evaluate.sh
new file mode 100644
index 000000000..fbc0eed9a
--- /dev/null
+++ b/egs/zipvoice/local/evaluate.sh
@@ -0,0 +1,102 @@
+export CUDA_VISIBLE_DEVICES="0"
+export PYTHONWARNINGS=ignore
+export PYTHONPATH=../../:$PYTHONPATH
+
+# Uncomment this if you have trouble connecting to HuggingFace
+# export HF_ENDPOINT=https://hf-mirror.com
+
+start_stage=1
+end_stage=3
+
+# Models used for SIM-o evaluation.
+# SV model wavlm_large_finetune.pth is downloaded from https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
+# SSL model wavlm_large.pt is downloaded from https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
+sv_model_path=model/UniSpeech/wavlm_large_finetune.pth
+wavlm_model_path=model/s3prl/wavlm_large.pt
+
+# Models used for UTMOS evaluation.
+# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
+# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
+utmos_model_path=model/huggingface/utmos/utmos.pt
+wav2vec_model_path=model/huggingface/utmos/wav2vec_small.pt
+
+
+if [ $start_stage -le 1 ] && [ $end_stage -ge 1 ]; then
+
+ echo "=====Evaluate for Seed-TTS test-en======="
+ test_list=testset/test_seedtts_en.tsv
+ wav_path=results/zipvoice_seedtts_en
+
+ echo $wav_path
+ echo "-----Computing SIM-o-----"
+ python3 local/evaluate_sim.py \
+ --sv-model-path ${sv_model_path} \
+ --ssl-model-path ${wavlm_model_path} \
+ --eval-path ${wav_path} \
+ --test-list ${test_list}
+
+ echo "-----Computing WER-----"
+ python3 local/evaluate_wer_seedtts.py \
+ --test-list ${test_list} \
+ --wav-path ${wav_path} \
+ --lang "en"
+
+ echo "-----Computing UTSMOS-----"
+ python3 local/evaluate_utmos.py \
+ --wav-path ${wav_path} \
+ --utmos-model-path ${utmos_model_path} \
+ --ssl-model-path ${wav2vec_model_path}
+
+fi
+
+if [ $start_stage -le 2 ] && [ $end_stage -ge 2 ]; then
+ echo "=====Evaluate for Seed-TTS test-zh======="
+ test_list=testset/test_seedtts_zh.tsv
+ wav_path=results/zipvoice_seedtts_zh
+
+ echo $wav_path
+ echo "-----Computing SIM-o-----"
+ python3 local/evaluate_sim.py \
+ --sv-model-path ${sv_model_path} \
+ --ssl-model-path ${wavlm_model_path} \
+ --eval-path ${wav_path} \
+ --test-list ${test_list}
+
+ echo "-----Computing WER-----"
+ python3 local/evaluate_wer_seedtts.py \
+ --test-list ${test_list} \
+ --wav-path ${wav_path} \
+ --lang "zh"
+
+ echo "-----Computing UTSMOS-----"
+ python3 local/evaluate_utmos.py \
+ --wav-path ${wav_path} \
+ --utmos-model-path ${utmos_model_path} \
+ --ssl-model-path ${wav2vec_model_path}
+fi
+
+if [ $start_stage -le 3 ] && [ $end_stage -ge 3 ]; then
+ echo "=====Evaluate for Librispeech test-clean======="
+ test_list=testset/test_librispeech_pc_test_clean.tsv
+ wav_path=results/zipvoice_librispeech_test_clean
+
+ echo $wav_path
+ echo "-----Computing SIM-o-----"
+ python3 local/evaluate_sim.py \
+ --sv-model-path ${sv_model_path} \
+ --ssl-model-path ${wavlm_model_path} \
+ --eval-path ${wav_path} \
+ --test-list ${test_list}
+
+ echo "-----Computing WER-----"
+ python3 local/evaluate_wer_hubert.py \
+ --test-list ${test_list} \
+        --wav-path ${wav_path}
+
+ echo "-----Computing UTSMOS-----"
+ python3 local/evaluate_utmos.py \
+ --wav-path ${wav_path} \
+ --utmos-model-path ${utmos_model_path} \
+ --ssl-model-path ${wav2vec_model_path}
+
+fi
\ No newline at end of file
diff --git a/egs/zipvoice/local/evaluate_sim.py b/egs/zipvoice/local/evaluate_sim.py
new file mode 100644
index 000000000..df439cf2c
--- /dev/null
+++ b/egs/zipvoice/local/evaluate_sim.py
@@ -0,0 +1,508 @@
+"""
+Calculate pairwise Speaker Similarity between two speech directories.
+SV model wavlm_large_finetune.pth is downloaded from
+ https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
+SSL model wavlm_large.pt is downloaded from
+ https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
+"""
+import argparse
+import logging
+import os
+
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+
+logging.basicConfig(level=logging.INFO)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--eval-path", type=str, help="path of the evaluated speech directory"
+ )
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ help="path of the file list that contains the corresponding "
+ "relationship between the prompt and evaluated speech. "
+ "The first column is the wav name and the third column is the prompt speech",
+ )
+ parser.add_argument(
+ "--sv-model-path",
+ type=str,
+ default="model/UniSpeech/wavlm_large_finetune.pth",
+ help="path of the wavlm-based ECAPA-TDNN model",
+ )
+ parser.add_argument(
+ "--ssl-model-path",
+ type=str,
+ default="model/s3prl/wavlm_large.pt",
+ help="path of the wavlm SSL model",
+ )
+ return parser
+
+
+class SpeakerSimilarity:
+ def __init__(
+ self,
+ sv_model_path="model/UniSpeech/wavlm_large_finetune.pth",
+ ssl_model_path="model/s3prl/wavlm_large.pt",
+ ):
+ """
+ Initialize
+ """
+ self.sample_rate = 16000
+ self.channels = 1
+ self.device = (
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ )
+ logging.info("[Speaker Similarity] Using device: {}".format(self.device))
+ self.model = ECAPA_TDNN_WAVLLM(
+ feat_dim=1024,
+ channels=512,
+ emb_dim=256,
+ sr=16000,
+ ssl_model_path=ssl_model_path,
+ )
+ state_dict = torch.load(
+ sv_model_path, map_location=lambda storage, loc: storage
+ )
+ self.model.load_state_dict(state_dict["model"], strict=False)
+ self.model.to(self.device)
+ self.model.eval()
+
+ def get_embeddings(self, wav_list, dtype="float32"):
+ """
+ Get embeddings
+ """
+
+ def _load_speech_task(fname, sample_rate):
+
+ wav_data, sr = sf.read(fname, dtype=dtype)
+ if sr != sample_rate:
+ wav_data = librosa.resample(
+ wav_data, orig_sr=sr, target_sr=self.sample_rate
+ )
+ wav_data = torch.from_numpy(wav_data)
+
+ return wav_data
+
+ embd_lst = []
+ for file_path in tqdm(wav_list):
+ speech = _load_speech_task(file_path, self.sample_rate)
+ speech = speech.to(self.device)
+ with torch.no_grad():
+ embd = self.model([speech])
+ embd_lst.append(embd)
+
+ return embd_lst
+
+ def score(
+ self,
+ eval_path,
+ test_list,
+ dtype="float32",
+ ):
+ """
+ Computes the Speaker Similarity (SIM-o) between two directories of speech files.
+
+ Parameters:
+ - eval_path (str): Path to the directory containing evaluation speech files.
+ - test_list (str): Path to the file containing the corresponding relationship
+ between prompt and evaluated speech.
+ - dtype (str, optional): Data type for loading speech. Default is "float32".
+
+ Returns:
+ - float: The Speaker Similarity (SIM-o) score between the two directories
+ of speech files.
+ """
+ prompt_wavs = []
+ eval_wavs = []
+ with open(test_list, "r") as fr:
+ lines = fr.readlines()
+ for line in lines:
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
+ prompt_wavs.append(prompt_wav)
+ eval_wavs.append(os.path.join(eval_path, wav_name + ".wav"))
+ embds_prompt = self.get_embeddings(prompt_wavs, dtype=dtype)
+
+ embds_eval = self.get_embeddings(eval_wavs, dtype=dtype)
+
+ # Check if embeddings are empty
+ if len(embds_prompt) == 0:
+ logging.info("[Speaker Similarity] real set dir is empty, exiting...")
+ return -1
+ if len(embds_eval) == 0:
+ logging.info("[Speaker Similarity] eval set dir is empty, exiting...")
+ return -1
+
+ scores = []
+ for real_embd, eval_embd in zip(embds_prompt, embds_eval):
+ scores.append(
+ torch.nn.functional.cosine_similarity(real_embd, eval_embd, dim=-1)
+ .detach()
+ .cpu()
+ .numpy()
+ )
+
+ return np.mean(scores)
+
+
+# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
+
+""" Res2Conv1d + BatchNorm1d + ReLU
+"""
+
+
+class Res2Conv1dReluBn(nn.Module):
+ """
+ in_channels == out_channels == channels
+ """
+
+ def __init__(
+ self,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=True,
+ scale=4,
+ ):
+ super().__init__()
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
+ self.scale = scale
+ self.width = channels // scale
+ self.nums = scale if scale == 1 else scale - 1
+
+ self.convs = []
+ self.bns = []
+ for i in range(self.nums):
+ self.convs.append(
+ nn.Conv1d(
+ self.width,
+ self.width,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ bias=bias,
+ )
+ )
+ self.bns.append(nn.BatchNorm1d(self.width))
+ self.convs = nn.ModuleList(self.convs)
+ self.bns = nn.ModuleList(self.bns)
+
+ def forward(self, x):
+ out = []
+ spx = torch.split(x, self.width, 1)
+ for i in range(self.nums):
+ if i == 0:
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ # Order: conv -> relu -> bn
+ sp = self.convs[i](sp)
+ sp = self.bns[i](F.relu(sp))
+ out.append(sp)
+ if self.scale != 1:
+ out.append(spx[self.nums])
+ out = torch.cat(out, dim=1)
+
+ return out
+
+
+""" Conv1d + BatchNorm1d + ReLU
+"""
+
+
+class Conv1dReluBn(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ dilation=1,
+ bias=True,
+ ):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
+ )
+ self.bn = nn.BatchNorm1d(out_channels)
+
+ def forward(self, x):
+ return self.bn(F.relu(self.conv(x)))
+
+
+""" The SE connection of 1D case.
+"""
+
+
+class SE_Connect(nn.Module):
+ def __init__(self, channels, se_bottleneck_dim=128):
+ super().__init__()
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
+
+ def forward(self, x):
+ out = x.mean(dim=2)
+ out = F.relu(self.linear1(out))
+ out = torch.sigmoid(self.linear2(out))
+ out = x * out.unsqueeze(2)
+
+ return out
+
+
+""" SE-Res2Block of the ECAPA-TDNN architecture.
+"""
+
+
+# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
+# return nn.Sequential(
+# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
+# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
+# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
+# SE_Connect(channels)
+# )
+
+
+class SE_Res2Block(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ scale,
+ se_bottleneck_dim,
+ ):
+ super().__init__()
+ self.Conv1dReluBn1 = Conv1dReluBn(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(
+ out_channels, kernel_size, stride, padding, dilation, scale=scale
+ )
+ self.Conv1dReluBn2 = Conv1dReluBn(
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
+
+ self.shortcut = None
+ if in_channels != out_channels:
+ self.shortcut = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ residual = x
+ if self.shortcut:
+ residual = self.shortcut(x)
+
+ x = self.Conv1dReluBn1(x)
+ x = self.Res2Conv1dReluBn(x)
+ x = self.Conv1dReluBn2(x)
+ x = self.SE_Connect(x)
+
+ return x + residual
+
+
+""" Attentive weighted mean and standard deviation pooling.
+"""
+
+
+class AttentiveStatsPool(nn.Module):
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
+ super().__init__()
+ self.global_context_att = global_context_att
+
+ # Use Conv1d with stride == 1 rather than Linear,
+ # then we don't need to transpose inputs.
+ if global_context_att:
+ self.linear1 = nn.Conv1d(
+ in_dim * 3, attention_channels, kernel_size=1
+ ) # equals W and b in the paper
+ else:
+ self.linear1 = nn.Conv1d(
+ in_dim, attention_channels, kernel_size=1
+ ) # equals W and b in the paper
+ self.linear2 = nn.Conv1d(
+ attention_channels, in_dim, kernel_size=1
+ ) # equals V and k in the paper
+
+ def forward(self, x):
+
+ if self.global_context_att:
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
+ context_std = torch.sqrt(
+ torch.var(x, dim=-1, keepdim=True) + 1e-10
+ ).expand_as(x)
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
+ else:
+ x_in = x
+
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
+ alpha = torch.tanh(self.linear1(x_in))
+ # alpha = F.relu(self.linear1(x_in))
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
+ mean = torch.sum(alpha * x, dim=2)
+ residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
+ std = torch.sqrt(residuals.clamp(min=1e-9))
+ return torch.cat([mean, std], dim=1)
+
+
+class ECAPA_TDNN_WAVLLM(nn.Module):
+ def __init__(
+ self,
+ feat_dim=80,
+ channels=512,
+ emb_dim=192,
+ global_context_att=False,
+ sr=16000,
+ ssl_model_path=None,
+ ):
+ super().__init__()
+ self.sr = sr
+
+ if ssl_model_path is None:
+ self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
+ else:
+ self.feature_extract = torch.hub.load(
+ os.path.dirname(ssl_model_path),
+ "wavlm_local",
+ source="local",
+ ckpt=ssl_model_path,
+ )
+
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+ self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
+ ):
+ self.feature_extract.model.encoder.layers[
+ 23
+ ].self_attn.fp32_attention = False
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
+ self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
+ ):
+ self.feature_extract.model.encoder.layers[
+ 11
+ ].self_attn.fp32_attention = False
+
+ self.feat_num = self.get_feat_num()
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
+
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
+ # self.channels = [channels] * 4 + [channels * 3]
+ self.channels = [channels] * 4 + [1536]
+
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
+ self.layer2 = SE_Res2Block(
+ self.channels[0],
+ self.channels[1],
+ kernel_size=3,
+ stride=1,
+ padding=2,
+ dilation=2,
+ scale=8,
+ se_bottleneck_dim=128,
+ )
+ self.layer3 = SE_Res2Block(
+ self.channels[1],
+ self.channels[2],
+ kernel_size=3,
+ stride=1,
+ padding=3,
+ dilation=3,
+ scale=8,
+ se_bottleneck_dim=128,
+ )
+ self.layer4 = SE_Res2Block(
+ self.channels[2],
+ self.channels[3],
+ kernel_size=3,
+ stride=1,
+ padding=4,
+ dilation=4,
+ scale=8,
+ se_bottleneck_dim=128,
+ )
+
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
+ cat_channels = channels * 3
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
+ self.pooling = AttentiveStatsPool(
+ self.channels[-1],
+ attention_channels=128,
+ global_context_att=global_context_att,
+ )
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
+
+ def get_feat_num(self):
+ self.feature_extract.eval()
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
+ with torch.no_grad():
+ features = self.feature_extract(wav)
+ select_feature = features["hidden_states"]
+ if isinstance(select_feature, (list, tuple)):
+ return len(select_feature)
+ else:
+ return 1
+
+ def get_feat(self, x):
+ with torch.no_grad():
+ x = self.feature_extract([sample for sample in x])
+
+ x = x["hidden_states"]
+ if isinstance(x, (list, tuple)):
+ x = torch.stack(x, dim=0)
+ else:
+ x = x.unsqueeze(0)
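+        # Combine all SSL hidden layers with learned (softmax-normalized) weights.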
+ norm_weights = (
+ F.softmax(self.feature_weight, dim=-1)
+ .unsqueeze(-1)
+ .unsqueeze(-1)
+ .unsqueeze(-1)
+ )
+ x = (norm_weights * x).sum(dim=0)
+ x = torch.transpose(x, 1, 2) + 1e-6
+
+ x = self.instance_norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.get_feat(x)
+
+ out1 = self.layer1(x)
+ out2 = self.layer2(out1)
+ out3 = self.layer3(out2)
+ out4 = self.layer4(out3)
+
+ out = torch.cat([out2, out3, out4], dim=1)
+ out = F.relu(self.conv(out))
+ out = self.bn(self.pooling(out))
+ out = self.linear(out)
+
+ return out
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ SIM = SpeakerSimilarity(
+ sv_model_path=args.sv_model_path, ssl_model_path=args.ssl_model_path
+ )
+ score = SIM.score(args.eval_path, args.test_list)
+ logging.info(f"SIM-o score: {score:.3f}")
diff --git a/egs/zipvoice/local/evaluate_utmos.py b/egs/zipvoice/local/evaluate_utmos.py
new file mode 100644
index 000000000..369e139c1
--- /dev/null
+++ b/egs/zipvoice/local/evaluate_utmos.py
@@ -0,0 +1,294 @@
+"""
+Calculate the UTMOS score with the automatic Mean Opinion Score (MOS) prediction
+system adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
+
+# Download model checkpoints
+wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
+wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
+"""
+
+import argparse
+import logging
+import os
+
+import fairseq
+import librosa
+import numpy as np
+import pytorch_lightning as pl
+import soundfile as sf
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+logging.basicConfig(level=logging.INFO)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--wav-path", type=str, help="path of the evaluated speech directory"
+ )
+ parser.add_argument(
+ "--utmos-model-path",
+ type=str,
+ default="model/huggingface/utmos/utmos.pt",
+ help="path of the UTMOS model",
+ )
+ parser.add_argument(
+ "--ssl-model-path",
+ type=str,
+ default="model/huggingface/utmos/wav2vec_small.pt",
+ help="path of the wav2vec SSL model",
+ )
+ return parser
+
+
+class UTMOSScore:
+ """Predicting score for each audio clip."""
+
+ def __init__(self, utmos_model_path, ssl_model_path):
+ self.sample_rate = 16000
+ self.device = (
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ )
+ self.model = (
+ BaselineLightningModule.load_from_checkpoint(
+ utmos_model_path, ssl_model_path=ssl_model_path
+ )
+ .eval()
+ .to(self.device)
+ )
+
+ def score(self, wavs: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+            wavs: waveforms to be evaluated. A 1-D or 2-D tensor is treated as
+                a single audio clip; a 3-D tensor of shape (batch, 1, time) is
+                processed as a batch.
+ """
+ if len(wavs.shape) == 1:
+ out_wavs = wavs.unsqueeze(0).unsqueeze(0)
+ elif len(wavs.shape) == 2:
+ out_wavs = wavs.unsqueeze(0)
+ elif len(wavs.shape) == 3:
+ out_wavs = wavs
+ else:
+ raise ValueError("Dimension of input tensor needs to be <= 3.")
+ bs = out_wavs.shape[0]
+ batch = {
+ "wav": out_wavs,
+ "domains": torch.zeros(bs, dtype=torch.int).to(self.device),
+ "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
+ }
+ with torch.no_grad():
+ output = self.model(batch)
+
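+        # Rescale the raw prediction to the 1-5 MOS range (following the UTMOS demo code).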
+ return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
+
+ def score_dir(self, dir, dtype="float32"):
+ def _load_speech_task(fname, sample_rate):
+
+ wav_data, sr = sf.read(fname, dtype=dtype)
+ if sr != sample_rate:
+ wav_data = librosa.resample(
+ wav_data, orig_sr=sr, target_sr=self.sample_rate
+ )
+ wav_data = torch.from_numpy(wav_data)
+
+ return wav_data
+
+ score_lst = []
+ for fname in tqdm(os.listdir(dir)):
+ speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate)
+ speech = speech.to(self.device)
+ with torch.no_grad():
+ score = self.score(speech)
+ score_lst.append(score.item())
+ return np.mean(score_lst)
+
+
+def load_ssl_model(ckpt_path="wav2vec_small.pt"):
+ SSL_OUT_DIM = 768
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+ [ckpt_path]
+ )
+ ssl_model = model[0]
+ ssl_model.remove_pretraining_modules()
+ return SSL_model(ssl_model, SSL_OUT_DIM)
+
+
+class BaselineLightningModule(pl.LightningModule):
+ def __init__(self, ssl_model_path):
+ super().__init__()
+ self.construct_model(ssl_model_path)
+ self.save_hyperparameters()
+
+ def construct_model(self, ssl_model_path):
+ self.feature_extractors = nn.ModuleList(
+ [
+ load_ssl_model(ckpt_path=ssl_model_path),
+ DomainEmbedding(3, 128),
+ ]
+ )
+ output_dim = sum(
+ [
+ feature_extractor.get_output_dim()
+ for feature_extractor in self.feature_extractors
+ ]
+ )
+ output_layers = [
+ LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
+ ]
+ output_dim = output_layers[-1].get_output_dim()
+ output_layers.append(
+ Projection(
+ hidden_dim=2048,
+ activation=torch.nn.ReLU(),
+ range_clipping=False,
+ input_dim=output_dim,
+ )
+ )
+
+ self.output_layers = nn.ModuleList(output_layers)
+
+ def forward(self, inputs):
+ outputs = {}
+ for feature_extractor in self.feature_extractors:
+ outputs.update(feature_extractor(inputs))
+ x = outputs
+ for output_layer in self.output_layers:
+ x = output_layer(x, inputs)
+ return x
+
+
+class SSL_model(nn.Module):
+ def __init__(self, ssl_model, ssl_out_dim) -> None:
+ super(SSL_model, self).__init__()
+ self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
+
+ def forward(self, batch):
+ wav = batch["wav"]
+ wav = wav.squeeze(1) # [batches, wav_len]
+ res = self.ssl_model(wav, mask=False, features_only=True)
+ x = res["x"]
+ return {"ssl-feature": x}
+
+ def get_output_dim(self):
+ return self.ssl_out_dim
+
+
+class DomainEmbedding(nn.Module):
+ def __init__(self, n_domains, domain_dim) -> None:
+ super().__init__()
+ self.embedding = nn.Embedding(n_domains, domain_dim)
+ self.output_dim = domain_dim
+
+ def forward(self, batch):
+ return {"domain-feature": self.embedding(batch["domains"])}
+
+ def get_output_dim(self):
+ return self.output_dim
+
+
+class LDConditioner(nn.Module):
+ """
+ Conditions ssl output by listener embedding
+ """
+
+ def __init__(self, input_dim, judge_dim, num_judges=None):
+ super().__init__()
+ self.input_dim = input_dim
+ self.judge_dim = judge_dim
+ self.num_judges = num_judges
+        assert num_judges is not None
+ self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
+ # concat [self.output_layer, phoneme features]
+
+ self.decoder_rnn = nn.LSTM(
+ input_size=self.input_dim + self.judge_dim,
+ hidden_size=512,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=True,
+ ) # linear?
+ self.out_dim = self.decoder_rnn.hidden_size * 2
+
+ def get_output_dim(self):
+ return self.out_dim
+
+ def forward(self, x, batch):
+ judge_ids = batch["judge_id"]
+ if "phoneme-feature" in x.keys():
+ concatenated_feature = torch.cat(
+ (
+ x["ssl-feature"],
+ x["phoneme-feature"]
+ .unsqueeze(1)
+ .expand(-1, x["ssl-feature"].size(1), -1),
+ ),
+ dim=2,
+ )
+ else:
+ concatenated_feature = x["ssl-feature"]
+ if "domain-feature" in x.keys():
+ concatenated_feature = torch.cat(
+ (
+ concatenated_feature,
+ x["domain-feature"]
+ .unsqueeze(1)
+ .expand(-1, concatenated_feature.size(1), -1),
+ ),
+ dim=2,
+ )
+        if judge_ids is not None:
+ concatenated_feature = torch.cat(
+ (
+ concatenated_feature,
+ self.judge_embedding(judge_ids)
+ .unsqueeze(1)
+ .expand(-1, concatenated_feature.size(1), -1),
+ ),
+ dim=2,
+ )
+ decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
+ return decoder_output
+
+
+class Projection(nn.Module):
+ def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
+ super(Projection, self).__init__()
+ self.range_clipping = range_clipping
+ output_dim = 1
+ if range_clipping:
+ self.proj = nn.Tanh()
+
+ self.net = nn.Sequential(
+ nn.Linear(input_dim, hidden_dim),
+ activation,
+ nn.Dropout(0.3),
+ nn.Linear(hidden_dim, output_dim),
+ )
+ self.output_dim = output_dim
+
+ def forward(self, x, batch):
+ output = self.net(x)
+
+ # range clipping
+ if self.range_clipping:
+ return self.proj(output) * 2.0 + 3
+ else:
+ return output
+
+ def get_output_dim(self):
+ return self.output_dim
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ UTMOS = UTMOSScore(
+ utmos_model_path=args.utmos_model_path, ssl_model_path=args.ssl_model_path
+ )
+ score = UTMOS.score_dir(args.wav_path)
+ logging.info(f"UTMOS score: {score:.2f}")
diff --git a/egs/zipvoice/local/evaluate_wer_hubert.py b/egs/zipvoice/local/evaluate_wer_hubert.py
new file mode 100644
index 000000000..d30346e67
--- /dev/null
+++ b/egs/zipvoice/local/evaluate_wer_hubert.py
@@ -0,0 +1,172 @@
+"""
+Calculate WER with Hubert models.
+"""
+import argparse
+import os
+import re
+from pathlib import Path
+
+import librosa
+import numpy as np
+import soundfile as sf
+import torch
+from jiwer import compute_measures
+from tqdm import tqdm
+from transformers import pipeline
+
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--wav-path", type=str, help="path of the speech directory")
+ parser.add_argument(
+ "--decode-path",
+ type=str,
+ default=None,
+ help="path of the output file of WER information",
+ )
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ default=None,
+ help="path of the local hubert model, e.g., model/huggingface/hubert-large-ls960-ft",
+ )
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default="test.tsv",
+ help="path of the transcript tsv file, where the first column "
+ "is the wav name and the last column is the transcript",
+ )
+ parser.add_argument(
+ "--batch-size", type=int, default=16, help="decoding batch size"
+ )
+ return parser
+
+
+def post_process(text: str):
+ text = text.replace("‘", "'")
+ text = text.replace("’", "'")
+ text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def process_one(hypo, truth):
+ truth = post_process(truth)
+ hypo = post_process(hypo)
+
+ measures = compute_measures(truth, hypo)
+ word_num = len(truth.split(" "))
+ wer = measures["wer"]
+ subs = measures["substitutions"]
+ dele = measures["deletions"]
+ inse = measures["insertions"]
+ return (truth, hypo, wer, subs, dele, inse, word_num)
+
+
+class SpeechEvalDataset(torch.utils.data.Dataset):
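+    """Dataset that reads the test tsv and yields each wav (resampled to 16 kHz)
+    together with its reference transcript and wav name."""
+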
+ def __init__(self, wav_path: str, test_list: str):
+ super().__init__()
+ self.wav_name = []
+ self.wav_paths = []
+ self.transcripts = []
+ with Path(test_list).open("r", encoding="utf8") as f:
+ meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
+ for item in meta:
+ self.wav_name.append(item[0])
+ self.wav_paths.append(Path(wav_path, item[0] + ".wav"))
+ self.transcripts.append(item[-1])
+
+ def __len__(self):
+ return len(self.wav_paths)
+
+ def __getitem__(self, index: int):
+ wav, sampling_rate = sf.read(self.wav_paths[index])
+ item = {
+ "array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000),
+ "sampling_rate": 16000,
+ "reference": self.transcripts[index],
+ "wav_name": self.wav_name[index],
+ }
+ return item
+
+
+def main(test_list, wav_path, model_path, decode_path, batch_size, device):
+
+ if model_path is not None:
+ pipe = pipeline(
+ "automatic-speech-recognition",
+ model=model_path,
+ device=device,
+ tokenizer=model_path,
+ )
+ else:
+ pipe = pipeline(
+ "automatic-speech-recognition",
+ model="facebook/hubert-large-ls960-ft",
+ device=device,
+ )
+
+ dataset = SpeechEvalDataset(wav_path, test_list)
+
+ bar = tqdm(
+ pipe(
+ dataset,
+ generate_kwargs={"language": "english", "task": "transcribe"},
+ batch_size=batch_size,
+ ),
+ total=len(dataset),
+ )
+
+ wers = []
+ inses = []
+ deles = []
+ subses = []
+ word_nums = 0
+ if decode_path:
+ decode_dir = os.path.dirname(decode_path)
+ if not os.path.exists(decode_dir):
+ os.makedirs(decode_dir)
+ fout = open(decode_path, "w")
+ for out in bar:
+ wav_name = out["wav_name"][0]
+ transcription = post_process(out["text"].strip())
+ text_ref = post_process(out["reference"][0].strip())
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
+ transcription, text_ref
+ )
+ if decode_path:
+ fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
+ wers.append(float(wer))
+ inses.append(float(inse))
+ deles.append(float(dele))
+ subses.append(float(subs))
+ word_nums += word_num
+
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
+ subs = round(np.mean(subses) * 100, 3)
+ dele = round(np.mean(deles) * 100, 3)
+ inse = round(np.mean(inses) * 100, 3)
+ print(f"WER: {wer}%\n")
+ if decode_path:
+ fout.write(f"WER: {wer}%\n")
+ fout.flush()
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ else:
+ device = torch.device("cpu")
+ main(
+ args.test_list,
+ args.wav_path,
+ args.model_path,
+ args.decode_path,
+ args.batch_size,
+ device,
+ )
diff --git a/egs/zipvoice/local/evaluate_wer_seedtts.py b/egs/zipvoice/local/evaluate_wer_seedtts.py
new file mode 100644
index 000000000..f7e256387
--- /dev/null
+++ b/egs/zipvoice/local/evaluate_wer_seedtts.py
@@ -0,0 +1,181 @@
+"""
+Calculate WER with Whisper-large-v3 or Paraformer models,
+following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval
+"""
+
+import argparse
+import os
+import string
+
+import numpy as np
+import scipy
+import soundfile as sf
+import torch
+import zhconv
+from funasr import AutoModel
+from jiwer import compute_measures
+from tqdm import tqdm
+from transformers import WhisperForConditionalGeneration, WhisperProcessor
+from zhon.hanzi import punctuation
+
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--wav-path", type=str, help="path of the speech directory")
+ parser.add_argument(
+ "--decode-path",
+ type=str,
+ default=None,
+ help="path of the output file of WER information",
+ )
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ default=None,
+ help="path of the local whisper and paraformer model, "
+ "e.g., whisper: model/huggingface/whisper-large-v3/, "
+ "paraformer: model/huggingface/paraformer-zh/",
+ )
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default="test.tsv",
+ help="path of the transcript tsv file, where the first column "
+ "is the wav name and the last column is the transcript",
+ )
+ parser.add_argument("--lang", type=str, help="decoded language, zh or en")
+ return parser
+
+
+def load_en_model(model_path):
+ if model_path is None:
+ model_path = "openai/whisper-large-v3"
+ processor = WhisperProcessor.from_pretrained(model_path)
+ model = WhisperForConditionalGeneration.from_pretrained(model_path)
+ return processor, model
+
+
+def load_zh_model(model_path):
+ if model_path is None:
+ model_path = "paraformer-zh"
+ model = AutoModel(model=model_path)
+ return model
+
+
+def process_one(hypo, truth, lang):
+ punctuation_all = punctuation + string.punctuation
+ for x in punctuation_all:
+ if x == "'":
+ continue
+ truth = truth.replace(x, "")
+ hypo = hypo.replace(x, "")
+
+    truth = truth.replace("  ", " ")
+    hypo = hypo.replace("  ", " ")
+
+ if lang == "zh":
+ truth = " ".join([x for x in truth])
+ hypo = " ".join([x for x in hypo])
+ elif lang == "en":
+ truth = truth.lower()
+ hypo = hypo.lower()
+ else:
+ raise NotImplementedError
+
+ measures = compute_measures(truth, hypo)
+ word_num = len(truth.split(" "))
+ wer = measures["wer"]
+ subs = measures["substitutions"]
+ dele = measures["deletions"]
+ inse = measures["insertions"]
+ return (truth, hypo, wer, subs, dele, inse, word_num)
+
+
+def main(test_list, wav_path, model_path, decode_path, lang, device):
+ if lang == "en":
+ processor, model = load_en_model(model_path)
+ model.to(device)
+ elif lang == "zh":
+ model = load_zh_model(model_path)
+ params = []
+ for line in open(test_list).readlines():
+ line = line.strip()
+ items = line.split("\t")
+ wav_name, text_ref = items[0], items[-1]
+ file_path = os.path.join(wav_path, wav_name + ".wav")
+ assert os.path.exists(file_path), f"{file_path}"
+
+ params.append((file_path, text_ref))
+ wers = []
+ inses = []
+ deles = []
+ subses = []
+ word_nums = 0
+ if decode_path:
+ decode_dir = os.path.dirname(decode_path)
+ if not os.path.exists(decode_dir):
+ os.makedirs(decode_dir)
+ fout = open(decode_path, "w")
+ for wav_path, text_ref in tqdm(params):
+ if lang == "en":
+ wav, sr = sf.read(wav_path)
+ if sr != 16000:
+ wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
+ input_features = processor(
+ wav, sampling_rate=16000, return_tensors="pt"
+ ).input_features
+ input_features = input_features.to(device)
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
+ language="english", task="transcribe"
+ )
+ predicted_ids = model.generate(
+ input_features, forced_decoder_ids=forced_decoder_ids
+ )
+ transcription = processor.batch_decode(
+ predicted_ids, skip_special_tokens=True
+ )[0]
+ elif lang == "zh":
+ res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
+ transcription = res[0]["text"]
+ transcription = zhconv.convert(transcription, "zh-cn")
+
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
+ transcription, text_ref, lang
+ )
+ if decode_path:
+ fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
+ wers.append(float(wer))
+ inses.append(float(inse))
+ deles.append(float(dele))
+ subses.append(float(subs))
+ word_nums += word_num
+
+ wer_avg = round(np.mean(wers) * 100, 3)
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
+ subs = round(np.mean(subses) * 100, 3)
+ dele = round(np.mean(deles) * 100, 3)
+ inse = round(np.mean(inses) * 100, 3)
+ print(f"Seed-TTS WER: {wer_avg}%\n")
+ print(f"WER: {wer}%\n")
+ if decode_path:
+ fout.write(f"SeedTTS WER: {wer_avg}%\n")
+ fout.write(f"WER: {wer}%\n")
+ fout.flush()
+
+
+if __name__ == "__main__":
+ parser = get_parser()
+ args = parser.parse_args()
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+ else:
+ device = torch.device("cpu")
+ main(
+ args.test_list,
+ args.wav_path,
+ args.model_path,
+ args.decode_path,
+ args.lang,
+ device,
+ )
diff --git a/egs/zipvoice/local/feature.py b/egs/zipvoice/local/feature.py
new file mode 100644
index 000000000..e7d484d10
--- /dev/null
+++ b/egs/zipvoice/local/feature.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchaudio
+from lhotse.features.base import FeatureExtractor, register_extractor
+from lhotse.utils import Seconds, compute_num_frames
+
+
+class MelSpectrogramFeatures(nn.Module):
+ def __init__(
+ self,
+ sampling_rate=24000,
+ n_mels=100,
+ n_fft=1024,
+ hop_length=256,
+ ):
+ super().__init__()
+
+ self.mel_spec = torchaudio.transforms.MelSpectrogram(
+ sample_rate=sampling_rate,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ n_mels=n_mels,
+ center=True,
+ power=1,
+ )
+
+ def forward(self, inp):
+ assert len(inp.shape) == 2
+
+ mel = self.mel_spec(inp)
+ logmel = mel.clamp(min=1e-7).log()
+ return logmel
+
+
+@dataclass
+class TorchAudioFbankConfig:
+ sampling_rate: int
+ n_mels: int
+ n_fft: int
+ hop_length: int
+
+
+@register_extractor
+class TorchAudioFbank(FeatureExtractor):
+
+ name = "TorchAudioFbank"
+ config_type = TorchAudioFbankConfig
+
+ def __init__(self, config):
+ super().__init__(config=config)
+
+ def _feature_fn(self, sample):
+ fbank = MelSpectrogramFeatures(
+ sampling_rate=self.config.sampling_rate,
+ n_mels=self.config.n_mels,
+ n_fft=self.config.n_fft,
+ hop_length=self.config.hop_length,
+ )
+
+ return fbank(sample)
+
+ @property
+ def device(self) -> Union[str, torch.device]:
+ return self.config.device
+
+ def feature_dim(self, sampling_rate: int) -> int:
+ return self.config.n_mels
+
+ def extract(
+ self,
+ samples: Union[np.ndarray, torch.Tensor],
+ sampling_rate: int,
+ ) -> Union[np.ndarray, torch.Tensor]:
+ # Check for sampling rate compatibility.
+ expected_sr = self.config.sampling_rate
+ assert sampling_rate == expected_sr, (
+ f"Mismatched sampling rate: extractor expects {expected_sr}, "
+ f"got {sampling_rate}"
+ )
+ is_numpy = False
+ if not isinstance(samples, torch.Tensor):
+ samples = torch.from_numpy(samples)
+ is_numpy = True
+
+ if len(samples.shape) == 1:
+ samples = samples.unsqueeze(0)
+ assert samples.ndim == 2, samples.shape
+ assert samples.shape[0] == 1, samples.shape
+
+ mel = self._feature_fn(samples).squeeze().t()
+
+ assert mel.ndim == 2, mel.shape
+ assert mel.shape[1] == self.config.n_mels, mel.shape
+
+ num_frames = compute_num_frames(
+ samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
+ )
+
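+        # Truncate or pad so that the number of frames matches the frame count
+        # Lhotse expects for this duration and frame shift.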
+ if mel.shape[0] > num_frames:
+ mel = mel[:num_frames]
+ elif mel.shape[0] < num_frames:
+ mel = mel.unsqueeze(0)
+ mel = torch.nn.functional.pad(
+ mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
+ ).squeeze(0)
+
+ if is_numpy:
+ return mel.cpu().numpy()
+ else:
+ return mel
+
+ @property
+ def frame_shift(self) -> Seconds:
+ return self.config.hop_length / self.config.sampling_rate
diff --git a/egs/zipvoice/local/prepare_libritts.sh b/egs/zipvoice/local/prepare_libritts.sh
new file mode 100755
index 000000000..b35065bb1
--- /dev/null
+++ b/egs/zipvoice/local/prepare_libritts.sh
@@ -0,0 +1,88 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=0
+stop_stage=5
+sampling_rate=24000
+nj=32
+
+dl_dir=$PWD/download
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/LibriTTS,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
+ #
+ if [ ! -d $dl_dir/LibriTTS ]; then
+ lhotse download libritts $dl_dir
+ fi
+
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare LibriTTS manifest"
+ # We assume that you have downloaded the LibriTTS corpus
+ # to $dl_dir/LibriTTS
+ mkdir -p data/manifests_libritts
+ if [ ! -e data/manifests_libritts/.libritts.done ]; then
+ lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests_libritts
+ touch data/manifests_libritts/.libritts.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute Fbank for LibriTTS"
+  mkdir -p data/fbank_libritts
+ if [ ! -e data/fbank_libritts/.libritts.done ]; then
+ ./local/compute_fbank_libritts.py --sampling-rate $sampling_rate
+ touch data/fbank_libritts/.libritts.done
+ fi
+
+ # Here we shuffle and combine the train-clean-100, train-clean-360 and
+ # train-other-500 together to form the training set.
+ if [ ! -f data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz ]; then
+ cat <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-100.jsonl.gz) \
+ <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-360.jsonl.gz) \
+ <(gunzip -c data/fbank_libritts/libritts_cuts_train-other-500.jsonl.gz) | \
+ shuf | gzip -c > data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz
+ fi
+
+ if [ ! -f data/fbank_libritts/libritts_cuts_train-clean-460.jsonl.gz ]; then
+ cat <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-100.jsonl.gz) \
+ <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-360.jsonl.gz) | \
+ shuf | gzip -c > data/fbank_libritts/libritts_cuts_train-clean-460.jsonl.gz
+ fi
+
+ if [ ! -e data/fbank_libritts/.libritts-validated.done ]; then
+ log "Validating data/fbank for LibriTTS"
+ ./local/validate_manifest.py \
+ data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz
+ touch data/fbank_libritts/.libritts-validated.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 4: Generate token file"
+ if [ ! -e data/tokens_libritts.txt ]; then
+ ./local/prepare_token_file_libritts.py --tokens data/tokens_libritts.txt
+ fi
+fi
\ No newline at end of file
diff --git a/egs/zipvoice/local/prepare_token_file_emilia.py b/egs/zipvoice/local/prepare_token_file_emilia.py
new file mode 100644
index 000000000..68af8d397
--- /dev/null
+++ b/egs/zipvoice/local/prepare_token_file_emilia.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file generates the file that maps tokens to IDs.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List
+
+from piper_phonemize import get_espeak_map
+from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--tokens",
+ type=Path,
+ default=Path("data/tokens_emilia.txt"),
+ help="Path to the dict that maps the text tokens to IDs",
+ )
+
+ parser.add_argument(
+ "--pinyin",
+ type=Path,
+ default=Path("local/pinyin.txt"),
+ help="Path to the all unique pinyin",
+ )
+
+ return parser.parse_args()
+
+
+def get_pinyin_tokens(pinyin: Path) -> List[str]:
+ phones = set()
+ with open(pinyin, "r") as f:
+ for line in f:
+ x = line.strip()
+ initial = to_initials(x, strict=False)
+ # don't want to share tokens with espeak tokens, so use tone3 style
+ finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
+ if initial != "":
+ # don't want to share tokens with espeak tokens, so add a '0' after each initial
+ phones.add(initial + "0")
+ if finals != "":
+ phones.add(finals)
+ return sorted(phones)
+
+
+def get_token2id(args):
+ """Get a dict that maps token to IDs, and save it to the given filename."""
+ all_tokens = get_espeak_map() # token: [token_id]
+ all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
+ # sort by token_id
+ all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
+
+ all_pinyin = get_pinyin_tokens(args.pinyin)
+ with open(args.tokens, "w", encoding="utf-8") as f:
+ for token, token_id in all_tokens:
+ f.write(f"{token} {token_id}\n")
+ num_espeak_tokens = len(all_tokens)
+ for i, pinyin in enumerate(all_pinyin):
+ f.write(f"{pinyin} {num_espeak_tokens + i}\n")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ get_token2id(args)
diff --git a/egs/zipvoice/local/prepare_token_file_libritts.py b/egs/zipvoice/local/prepare_token_file_libritts.py
new file mode 100644
index 000000000..374b02613
--- /dev/null
+++ b/egs/zipvoice/local/prepare_token_file_libritts.py
@@ -0,0 +1,31 @@
+import re
+from collections import Counter
+
+from lhotse import load_manifest_lazy
+
+
+def prepare_tokens(manifest_file, token_file):
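+    """Build a character-level token file from the manifest: "_" is fixed at
+    ID 0 (presumably reserved as the padding symbol), and the remaining
+    characters are sorted by descending frequency.
+    """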
+ counter = Counter()
+ manifest = load_manifest_lazy(manifest_file)
+ for cut in manifest:
+ line = re.sub(r"\s+", " ", cut.supervisions[0].text)
+ counter.update(line)
+
+ unique_chars = set(counter.keys())
+
+ if "_" in unique_chars:
+ unique_chars.remove("_")
+
+ sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True)
+
+ result = ["_"] + sorted_chars
+
+ with open(token_file, "w", encoding="utf-8") as file:
+ for index, char in enumerate(result):
+ file.write(f"{char} {index}\n")
+
+
+if __name__ == "__main__":
+ manifest_file = "data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz"
+ output_token_file = "data/tokens_libritts.txt"
+ prepare_tokens(manifest_file, output_token_file)
diff --git a/egs/zipvoice/local/prepare_tokens_emilia.py b/egs/zipvoice/local/prepare_tokens_emilia.py
new file mode 100644
index 000000000..023d57524
--- /dev/null
+++ b/egs/zipvoice/local/prepare_tokens_emilia.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
+# Zengrui Jin,
+# Wei Kang)
+# 2024 Tsinghua University (authors: Zengrui Jin,)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and saves new cuts with phoneme tokens.
+"""
+
+import argparse
+import glob
+import logging
+import re
+from concurrent.futures import ProcessPoolExecutor as Pool
+from pathlib import Path
+from typing import List
+
+import jieba
+from lhotse import load_manifest_lazy
+from tokenizer import Tokenizer, is_alphabet, is_chinese, is_hangul, is_japanese
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--subset",
+ type=str,
+        help="Subset of Emilia (e.g., ZH, EN).",
+ )
+
+ parser.add_argument(
+ "--jobs",
+ type=int,
+ default=50,
+        help="Number of parallel jobs for processing.",
+ )
+
+ parser.add_argument(
+ "--source-dir",
+ type=str,
+ default="data/manifests_emilia/splits",
+ help="The source directory of manifest files.",
+ )
+
+ parser.add_argument(
+ "--dest-dir",
+ type=str,
+ help="The destination directory of manifest files.",
+ )
+
+ return parser.parse_args()
+
+
+def tokenize_by_CJK_char(line: str) -> List[str]:
+ """
+ Tokenize a line of text with CJK char.
+
+    Note: All returned characters will be upper case.
+
+ Example:
+ input = "你好世界是 hello world 的中文"
+ output = [你, 好, 世, 界, 是, HELLO, WORLD, 的, 中, 文]
+
+ Args:
+ line:
+ The input text.
+
+ Return:
+      A list of tokens: each CJK character is a separate token, and non-CJK words are kept whole.
+ """
+    # The CJK ranges are from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
+ pattern = re.compile(
+ r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
+ )
+ chars = pattern.split(line.strip().upper())
+ char_list = []
+ for w in chars:
+ if w.strip():
+ char_list += w.strip().split()
+ return char_list
+
+
+def prepare_tokens_emilia(file_name: str, input_dir: Path, output_dir: Path):
+ logging.info(f"Processing {file_name}")
+ if (output_dir / file_name).is_file():
+ logging.info(f"{file_name} exists, skipping.")
+ return
+ jieba.setLogLevel(logging.INFO)
+ tokenizer = Tokenizer()
+
+ def _prepare_cut(cut):
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
+ text = cut.supervisions[0].text
+ cut.supervisions[0].normalized_text = text
+ tokens = tokenizer.texts_to_tokens([text])[0]
+ cut.tokens = tokens
+ return cut
+
+ def _filter_cut(cut):
+ text = cut.supervisions[0].text
+ duration = cut.supervisions[0].duration
+ chinese = []
+ english = []
+
+        # keep only Chinese characters, spaces, and alphabetic characters
+ clean_chars = []
+ for x in text:
+ if is_hangul(x):
+ logging.info(f"Delete cut with text containing Korean : {text}")
+ return False
+ if is_japanese(x):
+ logging.info(f"Delete cut with text containing Japanese : {text}")
+ return False
+ if is_chinese(x):
+ chinese.append(x)
+ clean_chars.append(x)
+ if is_alphabet(x):
+ english.append(x)
+ clean_chars.append(x)
+ if x == " ":
+ clean_chars.append(x)
+ if len(english) + len(chinese) == 0:
+            logging.info(f"Delete cut whose text has no valid chars : {text}")
+ return False
+
+ words = tokenize_by_CJK_char("".join(clean_chars))
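+        # Reject the cut if any word repeats 10 times in a row, which usually
+        # indicates a corrupted or looping transcription.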
+ for i in range(len(words) - 10):
+ if words[i : i + 10].count(words[i]) == 10:
+                logging.info(f"Delete cut whose text has too many repeats : {text}")
+ return False
+        # Expected speaking rate: 20 - 600 words per minute, e.g., a 30-word
+        # text must last between 3 s (at 600 wpm) and 90 s (at 20 wpm).
+ if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60:
+ logging.info(
+ f"Delete cut with audio text mismatch, duration : {duration}s, words : {len(words)}, text : {text}"
+ )
+ return False
+ return True
+
+ try:
+ cut_set = load_manifest_lazy(input_dir / file_name)
+ cut_set = cut_set.filter(_filter_cut)
+ cut_set = cut_set.map(_prepare_cut)
+ cut_set.to_file(output_dir / file_name)
+ except Exception as e:
+ logging.error(f"Manifest {file_name} failed with error: {e}")
+ raise
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+
+ input_dir = Path(args.source_dir)
+ output_dir = Path(args.dest_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")
+
+ with Pool(max_workers=args.jobs) as pool:
+ futures = [
+ pool.submit(
+ prepare_tokens_emilia, filename.split("/")[-1], input_dir, output_dir
+ )
+ for filename in cut_files
+ ]
+    for f in futures:
+        try:
+            # result() blocks until the job finishes and re-raises any error
+            # raised inside the worker process.
+            f.result()
+ except Exception as e:
+ logging.error(f"Future failed with error: {e}")
+ logging.info("Processing done.")
diff --git a/egs/zipvoice/local/validate_manifest.py b/egs/zipvoice/local/validate_manifest.py
new file mode 100755
index 000000000..68159ae03
--- /dev/null
+++ b/egs/zipvoice/local/validate_manifest.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+
+We will add more checks later if needed.
+
+Usage example:
+
+ python3 ./local/validate_manifest.py \
+ ./data/spectrogram/ljspeech_cuts_all.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset.speech_synthesis import validate_for_tts
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ manifest = args.manifest
+ logging.info(f"Validating {manifest}")
+
+ assert manifest.is_file(), f"{manifest} does not exist"
+ cut_set = load_manifest_lazy(manifest)
+ assert isinstance(cut_set, CutSet), type(cut_set)
+
+ validate_for_tts(cut_set)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/zipvoice/zipvoice/checkpoint.py b/egs/zipvoice/zipvoice/checkpoint.py
new file mode 100644
index 000000000..e3acd57dd
--- /dev/null
+++ b/egs/zipvoice/zipvoice/checkpoint.py
@@ -0,0 +1,142 @@
+# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import torch
+import torch.nn as nn
+from lhotse.dataset.sampling.base import CutSampler
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+
+# use duck typing for LRScheduler since we have different possibilities, see
+# our class LRScheduler.
+LRSchedulerType = object
+
+
+def save_checkpoint(
+ filename: Path,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ model_ema: Optional[nn.Module] = None,
+ params: Optional[Dict[str, Any]] = None,
+ optimizer: Optional[Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ scaler: Optional[GradScaler] = None,
+ sampler: Optional[CutSampler] = None,
+ rank: int = 0,
+) -> None:
+ """Save training information to a file.
+
+ Args:
+ filename:
+ The checkpoint filename.
+ model:
+ The model to be saved. We only save its `state_dict()`.
+ model_avg:
+ The stored model averaged from the start of training.
+ model_ema:
+ The EMA version of model.
+ params:
+ User defined parameters, e.g., epoch, loss.
+ optimizer:
+ The optimizer to be saved. We only save its `state_dict()`.
+ scheduler:
+ The scheduler to be saved. We only save its `state_dict()`.
+      scaler:
+ The GradScaler to be saved. We only save its `state_dict()`.
+ sampler:
+ The sampler used in the labeled training dataset. We only
+ save its `state_dict()`.
+ rank:
+ Used in DDP. We save checkpoint only for the node whose
+ rank is 0.
+ Returns:
+ Return None.
+ """
+ if rank != 0:
+ return
+
+ logging.info(f"Saving checkpoint to {filename}")
+
+ if isinstance(model, DDP):
+ model = model.module
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict() if optimizer is not None else None,
+ "scheduler": scheduler.state_dict() if scheduler is not None else None,
+ "grad_scaler": scaler.state_dict() if scaler is not None else None,
+ "sampler": sampler.state_dict() if sampler is not None else None,
+ }
+
+ if model_avg is not None:
+ checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
+ if model_ema is not None:
+ checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict()
+
+ if params:
+ for k, v in params.items():
+ assert k not in checkpoint
+ checkpoint[k] = v
+
+ torch.save(checkpoint, filename)
+
+
+def load_checkpoint(
+ filename: Path,
+ model: Optional[nn.Module] = None,
+ model_avg: Optional[nn.Module] = None,
+ model_ema: Optional[nn.Module] = None,
+ strict: bool = False,
+) -> Dict[str, Any]:
+ logging.info(f"Loading checkpoint from {filename}")
+ checkpoint = torch.load(filename, map_location="cpu")
+
+ if model is not None:
+
+ if next(iter(checkpoint["model"])).startswith("module."):
+ logging.info("Loading checkpoint saved by DDP")
+
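+            # The checkpoint was saved from a DDP-wrapped model, so every key
+            # carries a "module." prefix; strip it by looking up each plain key
+            # with the prefix added, and make sure nothing is left over.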
+ dst_state_dict = model.state_dict()
+ src_state_dict = checkpoint["model"]
+ for key in dst_state_dict.keys():
+ src_key = "{}.{}".format("module", key)
+ dst_state_dict[key] = src_state_dict.pop(src_key)
+ assert len(src_state_dict) == 0
+ model.load_state_dict(dst_state_dict, strict=strict)
+ else:
+ logging.info("Loading checkpoint")
+ model.load_state_dict(checkpoint["model"], strict=strict)
+
+ checkpoint.pop("model")
+
+ if model_avg is not None and "model_avg" in checkpoint:
+ logging.info("Loading averaged model")
+ model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
+ checkpoint.pop("model_avg")
+
+ if model_ema is not None and "model_ema" in checkpoint:
+ logging.info("Loading ema model")
+ model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
+ checkpoint.pop("model_ema")
+
+ return checkpoint
diff --git a/egs/zipvoice/zipvoice/feature.py b/egs/zipvoice/zipvoice/feature.py
new file mode 100644
index 000000000..e7d484d10
--- /dev/null
+++ b/egs/zipvoice/zipvoice/feature.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchaudio
+from lhotse.features.base import FeatureExtractor, register_extractor
+from lhotse.utils import Seconds, compute_num_frames
+
+
+class MelSpectrogramFeatures(nn.Module):
+ def __init__(
+ self,
+ sampling_rate=24000,
+ n_mels=100,
+ n_fft=1024,
+ hop_length=256,
+ ):
+ super().__init__()
+
+ self.mel_spec = torchaudio.transforms.MelSpectrogram(
+ sample_rate=sampling_rate,
+ n_fft=n_fft,
+ hop_length=hop_length,
+ n_mels=n_mels,
+ center=True,
+ power=1,
+ )
+
+ def forward(self, inp):
+ assert len(inp.shape) == 2
+
+ mel = self.mel_spec(inp)
+ logmel = mel.clamp(min=1e-7).log()
+ return logmel
+
+
+@dataclass
+class TorchAudioFbankConfig:
+ sampling_rate: int
+ n_mels: int
+ n_fft: int
+    hop_length: int
+    # Referenced by the TorchAudioFbank.device property below.
+    device: str = "cpu"
+
+
+@register_extractor
+class TorchAudioFbank(FeatureExtractor):
+
+ name = "TorchAudioFbank"
+ config_type = TorchAudioFbankConfig
+
+ def __init__(self, config):
+ super().__init__(config=config)
+
+ def _feature_fn(self, sample):
+ fbank = MelSpectrogramFeatures(
+ sampling_rate=self.config.sampling_rate,
+ n_mels=self.config.n_mels,
+ n_fft=self.config.n_fft,
+ hop_length=self.config.hop_length,
+ )
+
+ return fbank(sample)
+
+ @property
+ def device(self) -> Union[str, torch.device]:
+ return self.config.device
+
+ def feature_dim(self, sampling_rate: int) -> int:
+ return self.config.n_mels
+
+ def extract(
+ self,
+ samples: Union[np.ndarray, torch.Tensor],
+ sampling_rate: int,
+ ) -> Union[np.ndarray, torch.Tensor]:
+ # Check for sampling rate compatibility.
+ expected_sr = self.config.sampling_rate
+ assert sampling_rate == expected_sr, (
+ f"Mismatched sampling rate: extractor expects {expected_sr}, "
+ f"got {sampling_rate}"
+ )
+ is_numpy = False
+ if not isinstance(samples, torch.Tensor):
+ samples = torch.from_numpy(samples)
+ is_numpy = True
+
+ if len(samples.shape) == 1:
+ samples = samples.unsqueeze(0)
+ assert samples.ndim == 2, samples.shape
+ assert samples.shape[0] == 1, samples.shape
+
+ mel = self._feature_fn(samples).squeeze().t()
+
+ assert mel.ndim == 2, mel.shape
+ assert mel.shape[1] == self.config.n_mels, mel.shape
+
+ num_frames = compute_num_frames(
+ samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
+ )
+
+ if mel.shape[0] > num_frames:
+ mel = mel[:num_frames]
+ elif mel.shape[0] < num_frames:
+ mel = mel.unsqueeze(0)
+ mel = torch.nn.functional.pad(
+ mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
+ ).squeeze(0)
+
+ if is_numpy:
+ return mel.cpu().numpy()
+ else:
+ return mel
+
+ @property
+ def frame_shift(self) -> Seconds:
+ return self.config.hop_length / self.config.sampling_rate
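+
+
+# A minimal usage sketch (the values mirror how infer.py builds the extractor;
+# `waveform` is assumed to be a (1, num_samples) tensor sampled at 24 kHz):
+#
+#   config = TorchAudioFbankConfig(
+#       sampling_rate=24000, n_mels=100, n_fft=1024, hop_length=256
+#   )
+#   extractor = TorchAudioFbank(config)
+#   feats = extractor.extract(waveform, sampling_rate=24000)  # (num_frames, 100)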
diff --git a/egs/zipvoice/zipvoice/generate_averaged_model.py b/egs/zipvoice/zipvoice/generate_averaged_model.py
new file mode 100755
index 000000000..e1b7ca7c6
--- /dev/null
+++ b/egs/zipvoice/zipvoice/generate_averaged_model.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+This script loads checkpoints and averages them.
+
+(1) Average ZipVoice models before distill:
+ python3 ./zipvoice/generate_averaged_model.py \
+ --epoch 11 \
+ --avg 4 \
+ --distill 0 \
+ --token-file data/tokens_emilia.txt \
+ --exp-dir ./zipvoice/exp_zipvoice
+
+  It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
+ You can later load it by `torch.load("epoch-11-avg-4.pt")`.
+
+(2) Average ZipVoice-Distill models (the first stage model):
+
+ python3 ./zipvoice/generate_averaged_model.py \
+ --iter 60000 \
+ --avg 7 \
+ --distill 1 \
+ --token-file data/tokens_emilia.txt \
+ --exp-dir ./zipvoice/exp_zipvoice_distill_1stage
+"""
+
+import argparse
+from pathlib import Path
+
+import torch
+from model import get_distill_model, get_model
+from tokenizer import TokenizerEmilia, TokenizerLibriTTS
+from train_flow import add_model_arguments, get_params
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=11,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=4,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' or '--iter'",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipvoice/exp_zipvoice",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--distill",
+ type=str2bool,
+ default=False,
+        help="Whether to use the distilled model.",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="emilia",
+ choices=["emilia", "libritts"],
+        help="The dataset the model was trained on; it determines which tokenizer is used.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ if params.dataset == "emilia":
+ tokenizer = TokenizerEmilia(
+ token_file=params.token_file, token_type=params.token_type
+ )
+ elif params.dataset == "libritts":
+ tokenizer = TokenizerLibriTTS(
+ token_file=params.token_file, token_type=params.token_type
+ )
+
+ params.vocab_size = tokenizer.vocab_size
+ params.pad_id = tokenizer.pad_id
+
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ print("Script started")
+
+ params.device = torch.device("cpu")
+ print(f"Device: {params.device}")
+
+ print("About to create model")
+ if params.distill:
+ model = get_distill_model(params)
+ else:
+ model = get_model(params)
+
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ print(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(params.device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=params.device,
+ ),
+ strict=True,
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ print(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(params.device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=params.device,
+ ),
+ strict=True,
+ )
+ if params.iter > 0:
+ filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
+ else:
+ filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
+ torch.save({"model": model.state_dict()}, filename)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ print(f"Number of model parameters: {num_param}")
+
+ print("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/zipvoice/zipvoice/infer.py b/egs/zipvoice/zipvoice/infer.py
new file mode 100644
index 000000000..2819d3c85
--- /dev/null
+++ b/egs/zipvoice/zipvoice/infer.py
@@ -0,0 +1,586 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
+# Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script loads checkpoints to generate waveforms.
+This script is intended for models you trained yourself.
+If you want to use the pre-trained checkpoints provided by us, please refer to zipvoice_infer.py.
+
+(1) Usage with a pre-trained checkpoint:
+
+ (a) ZipVoice model before distill:
+ python3 zipvoice/infer.py \
+ --checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
+ --distill 0 \
+ --token-file "data/tokens_emilia.txt" \
+ --test-list test.tsv \
+ --res-dir results/test \
+ --num-step 16 \
+ --guidance-scale 1
+
+ (b) ZipVoice-Distill:
+ python3 zipvoice/infer.py \
+ --checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
+ --distill 1 \
+ --token-file "data/tokens_emilia.txt" \
+ --test-list test.tsv \
+ --res-dir results/test_distill \
+ --num-step 8 \
+ --guidance-scale 3
+
+(2) Usage with a directory of checkpoints (may require checkpoint averaging):
+
+ (a) ZipVoice model before distill:
+    python3 zipvoice/infer.py \
+ --exp-dir zipvoice/exp_zipvoice \
+ --epoch 11 \
+ --avg 4 \
+ --distill 0 \
+ --token-file "data/tokens_emilia.txt" \
+ --test-list test.tsv \
+ --res-dir results \
+ --num-step 16 \
+ --guidance-scale 1
+
+ (b) ZipVoice-Distill:
+    python3 zipvoice/infer.py \
+ --exp-dir zipvoice/exp_zipvoice_distill/ \
+ --iter 2000 \
+ --avg 0 \
+ --distill 1 \
+ --token-file "data/tokens_emilia.txt" \
+ --test-list test.tsv \
+ --res-dir results \
+ --num-step 8 \
+ --guidance-scale 3
+"""
+
+
+import argparse
+import datetime as dt
+import logging
+import os
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import soundfile as sf
+import torch
+import torch.nn as nn
+import torchaudio
+from checkpoint import load_checkpoint
+from feature import TorchAudioFbank, TorchAudioFbankConfig
+from lhotse.utils import fix_random_seed
+from model import get_distill_model, get_model
+from tokenizer import TokenizerEmilia, TokenizerLibriTTS
+from train_flow import add_model_arguments, get_params
+from vocos import Vocos
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ default=None,
+ help="The checkpoint for inference. "
+ "If it is None, will use checkpoints under exp_dir",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipvoice/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=0,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=4,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' or '--iter', avg=0 means no avg",
+ )
+
+ parser.add_argument(
+ "--vocoder-path",
+ type=str,
+ default=None,
+        help="Path to a local Vocos vocoder directory downloaded from "
+        "HuggingFace; if it is None, the vocoder will be downloaded from HuggingFace.",
+ )
+
+ parser.add_argument(
+ "--distill",
+ type=str2bool,
+ default=False,
+ help="Whether it is a distilled TTS model.",
+ )
+
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default=None,
+ help="The list of prompt speech, prompt_transcription, "
+ "and text to synthesize in the format of "
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
+ )
+
+ parser.add_argument(
+ "--res-dir",
+ type=str,
+ default="results",
+ help="Path name of the generated wavs dir",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="emilia",
+ choices=["emilia", "libritts"],
+        help="The dataset the model was trained on; it determines which tokenizer is used.",
+ )
+
+ parser.add_argument(
+ "--guidance-scale",
+ type=float,
+ default=1.0,
+ help="The scale of classifier-free guidance during inference.",
+ )
+
+ parser.add_argument(
+ "--num-step",
+ type=int,
+ default=16,
+ help="The number of sampling steps.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--speed",
+ type=float,
+ default=1.0,
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
+ )
+
+ parser.add_argument(
+ "--t-shift",
+ type=float,
+ default=0.5,
+        help="Shift sampling timesteps towards smaller values when t_shift < 1.0.",
+ )
+
+ parser.add_argument(
+ "--target-rms",
+ type=float,
+ default=0.1,
+ help="Target speech normalization rms value",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=666,
+ help="Random seed",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_vocoder(vocos_local_path: Optional[str] = None):
+ if vocos_local_path:
+        # Load the vocoder from the user-provided local directory.
+        vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+ state_dict = torch.load(
+ f"{vocos_local_path}/pytorch_model.bin",
+ weights_only=True,
+ map_location="cpu",
+ )
+ vocoder.load_state_dict(state_dict)
+ else:
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+ return vocoder
+
+
+def generate_sentence(
+ save_path: str,
+ prompt_text: str,
+ prompt_wav: str,
+ text: str,
+ model: nn.Module,
+ vocoder: nn.Module,
+ tokenizer: TokenizerEmilia,
+ feature_extractor: TorchAudioFbank,
+ device: torch.device,
+ num_step: int = 16,
+ guidance_scale: float = 1.0,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+):
+ """
+ Generate waveform of a text based on a given prompt
+ waveform and its transcription.
+
+ Args:
+ save_path (str): Path to save the generated wav.
+ prompt_text (str): Transcription of the prompt wav.
+ prompt_wav (str): Path to the prompt wav file.
+ text (str): Text to be synthesized into a waveform.
+ model (nn.Module): The model used for generation.
+ vocoder (nn.Module): The vocoder used to convert features to waveforms.
+ tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
+ feature_extractor (TorchAudioFbank): The feature extractor used to
+ extract acoustic features.
+ device (torch.device): The device on which computations are performed.
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
+ guidance_scale (float, optional): Scale for classifier-free guidance.
+ Defaults to 1.0.
+ speed (float, optional): Speed control. Defaults to 1.0.
+ t_shift (float, optional): Time shift. Defaults to 0.5.
+ target_rms (float, optional): Target RMS for waveform normalization.
+ Defaults to 0.1.
+ feat_scale (float, optional): Scale for features.
+ Defaults to 0.1.
+ sampling_rate (int, optional): Sampling rate for the waveform.
+ Defaults to 24000.
+ Returns:
+ metrics (dict): Dictionary containing time and real-time
+ factor metrics for processing.
+ """
+ # Convert text to tokens
+ tokens = tokenizer.texts_to_token_ids([text])
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
+
+ # Load and preprocess prompt wav
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
+ if prompt_rms < target_rms:
+ prompt_wav = prompt_wav * target_rms / prompt_rms
+
+ if prompt_sampling_rate != sampling_rate:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
+ )
+ prompt_wav = resampler(prompt_wav)
+
+ # Extract features from prompt wav
+ prompt_features = feature_extractor.extract(
+ prompt_wav, sampling_rate=sampling_rate
+ ).to(device)
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
+
+ # Start timing
+ start_t = dt.datetime.now()
+
+ # Generate features
+ (
+ pred_features,
+ pred_features_lens,
+ pred_prompt_features,
+ pred_prompt_features_lens,
+ ) = model.sample(
+ tokens=tokens,
+ prompt_tokens=prompt_tokens,
+ prompt_features=prompt_features,
+ prompt_features_lens=prompt_features_lens,
+ speed=speed,
+ t_shift=t_shift,
+ duration="predict",
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ )
+
+ # Postprocess predicted features
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
+
+ # Start vocoder processing
+ start_vocoder_t = dt.datetime.now()
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
+
+ # Calculate processing times and real-time factors
+ t = (dt.datetime.now() - start_t).total_seconds()
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
+ wav_seconds = wav.shape[-1] / sampling_rate
+ rtf = t / wav_seconds
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
+ rtf_vocoder = t_vocoder / wav_seconds
+ metrics = {
+ "t": t,
+ "t_no_vocoder": t_no_vocoder,
+ "t_vocoder": t_vocoder,
+ "wav_seconds": wav_seconds,
+ "rtf": rtf,
+ "rtf_no_vocoder": rtf_no_vocoder,
+ "rtf_vocoder": rtf_vocoder,
+ }
+
+ # Adjust wav volume if necessary
+ if prompt_rms < target_rms:
+ wav = wav * prompt_rms / target_rms
+ wav = wav[0].cpu().numpy()
+ sf.write(save_path, wav, sampling_rate)
+
+ return metrics
+
+
+def generate(
+ params: AttributeDict,
+ model: nn.Module,
+ vocoder: nn.Module,
+ tokenizer: TokenizerEmilia,
+):
+ total_t = []
+ total_t_no_vocoder = []
+ total_t_vocoder = []
+ total_wav_seconds = []
+
+ config = TorchAudioFbankConfig(
+ sampling_rate=params.sampling_rate,
+ n_mels=100,
+ n_fft=1024,
+ hop_length=256,
+ )
+ feature_extractor = TorchAudioFbank(config)
+
+ with open(params.test_list, "r") as fr:
+ lines = fr.readlines()
+
+ for i, line in enumerate(lines):
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
+ save_path = f"{params.wav_dir}/{wav_name}.wav"
+ metrics = generate_sentence(
+ save_path=save_path,
+ prompt_text=prompt_text,
+ prompt_wav=prompt_wav,
+ text=text,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=params.device,
+ num_step=params.num_step,
+ guidance_scale=params.guidance_scale,
+ speed=params.speed,
+ t_shift=params.t_shift,
+ target_rms=params.target_rms,
+ feat_scale=params.feat_scale,
+ sampling_rate=params.sampling_rate,
+ )
+ print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
+ total_t.append(metrics["t"])
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
+ total_t_vocoder.append(metrics["t_vocoder"])
+ total_wav_seconds.append(metrics["wav_seconds"])
+
+ print(f"Average RTF: " f"{np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
+ print(
+ f"Average RTF w/o vocoder: "
+ f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
+ )
+ print(
+ f"Average RTF vocoder: "
+ f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
+ )
+
+
+@torch.inference_mode()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ if params.iter > 0:
+ params.suffix = (
+ f"wavs-iter-{params.iter}-avg"
+ f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
+ )
+ elif params.epoch > 0:
+ params.suffix = (
+ f"wavs-epoch-{params.epoch}-avg"
+ f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
+ )
+ else:
+ params.suffix = "wavs"
+
+ setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+ logging.info("Decoding started")
+
+ if torch.cuda.is_available():
+ params.device = torch.device("cuda", 0)
+ else:
+ params.device = torch.device("cpu")
+
+ logging.info(f"Device: {params.device}")
+
+ if params.dataset == "emilia":
+ tokenizer = TokenizerEmilia(
+ token_file=params.token_file, token_type=params.token_type
+ )
+ elif params.dataset == "libritts":
+ tokenizer = TokenizerLibriTTS(
+ token_file=params.token_file, token_type=params.token_type
+ )
+
+ params.vocab_size = tokenizer.vocab_size
+ params.pad_id = tokenizer.pad_id
+
+ logging.info(params)
+ fix_random_seed(params.seed)
+
+ logging.info("About to create model")
+ if params.distill:
+ model = get_distill_model(params)
+ else:
+ model = get_model(params)
+
+ if params.checkpoint:
+ load_checkpoint(params.checkpoint, model, strict=True)
+ else:
+ if params.avg == 0:
+ if params.iter > 0:
+ load_checkpoint(
+ f"{params.exp_dir}/checkpoint-{params.iter}.pt", model, strict=True
+ )
+ else:
+ load_checkpoint(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True
+ )
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(params.device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=params.device,
+ ),
+ strict=True,
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(params.device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=params.device,
+ ),
+ strict=True,
+ )
+
+ model = model.to(params.device)
+ model.eval()
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ vocoder = get_vocoder(params.vocoder_path)
+ vocoder = vocoder.to(params.device)
+ vocoder.eval()
+ num_param = sum([p.numel() for p in vocoder.parameters()])
+ logging.info(f"Number of vocoder parameters: {num_param}")
+
+ params.wav_dir = f"{params.res_dir}/{params.suffix}"
+ os.makedirs(params.wav_dir, exist_ok=True)
+
+ assert (
+ params.test_list is not None
+    ), "Please provide --test-list for speech synthesis."
+ generate(
+ params=params,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ main()
diff --git a/egs/zipvoice/zipvoice/model.py b/egs/zipvoice/zipvoice/model.py
new file mode 100644
index 000000000..25c7973b2
--- /dev/null
+++ b/egs/zipvoice/zipvoice/model.py
@@ -0,0 +1,578 @@
+# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
+# Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from scaling import ScheduledFloat
+from solver import EulerSolver
+from torch.nn.parallel import DistributedDataParallel as DDP
+from utils import (
+ AttributeDict,
+ condition_time_mask,
+ get_tokens_index,
+ make_pad_mask,
+ pad_labels,
+ prepare_avg_tokens_durations,
+ to_int_tuple,
+)
+from zipformer import TTSZipformer
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ """Get the normal TTS model."""
+
+ fm_decoder = get_fm_decoder_model(params)
+ text_encoder = get_text_encoder_model(params)
+
+ model = TtsModel(
+ fm_decoder=fm_decoder,
+ text_encoder=text_encoder,
+ text_embed_dim=params.text_embed_dim,
+ feat_dim=params.feat_dim,
+ vocab_size=params.vocab_size,
+ pad_id=params.pad_id,
+ )
+ return model
+
+
+def get_distill_model(params: AttributeDict) -> nn.Module:
+ """Get the distillation TTS model."""
+
+ fm_decoder = get_fm_decoder_model(params, distill=True)
+ text_encoder = get_text_encoder_model(params)
+
+ model = DistillTTSModelTrainWrapper(
+ fm_decoder=fm_decoder,
+ text_encoder=text_encoder,
+ text_embed_dim=params.text_embed_dim,
+ feat_dim=params.feat_dim,
+ vocab_size=params.vocab_size,
+ pad_id=params.pad_id,
+ )
+ return model
+
+
+def get_fm_decoder_model(params: AttributeDict, distill: bool = False) -> nn.Module:
+ """Get the Zipformer-based FM decoder model."""
+
+ encoder = TTSZipformer(
+ in_dim=params.feat_dim * 3,
+ out_dim=params.feat_dim,
+ downsampling_factor=to_int_tuple(params.fm_decoder_downsampling_factor),
+ num_encoder_layers=to_int_tuple(params.fm_decoder_num_layers),
+ cnn_module_kernel=to_int_tuple(params.fm_decoder_cnn_module_kernel),
+ encoder_dim=params.fm_decoder_dim,
+ feedforward_dim=params.fm_decoder_feedforward_dim,
+ num_heads=params.fm_decoder_num_heads,
+ query_head_dim=params.query_head_dim,
+ pos_head_dim=params.pos_head_dim,
+ value_head_dim=params.value_head_dim,
+ pos_dim=params.pos_dim,
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ use_time_embed=True,
+ time_embed_dim=192,
+ use_guidance_scale_embed=distill,
+ )
+ return encoder
+
+
+def get_text_encoder_model(params: AttributeDict) -> nn.Module:
+ """Get the Zipformer-based text encoder model."""
+
+ encoder = TTSZipformer(
+ in_dim=params.text_embed_dim,
+ out_dim=params.feat_dim,
+ downsampling_factor=to_int_tuple(params.text_encoder_downsampling_factor),
+ num_encoder_layers=to_int_tuple(params.text_encoder_num_layers),
+ cnn_module_kernel=to_int_tuple(params.text_encoder_cnn_module_kernel),
+ encoder_dim=params.text_encoder_dim,
+ feedforward_dim=params.text_encoder_feedforward_dim,
+ num_heads=params.text_encoder_num_heads,
+ query_head_dim=params.query_head_dim,
+ pos_head_dim=params.pos_head_dim,
+ value_head_dim=params.value_head_dim,
+ pos_dim=params.pos_dim,
+ dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+ warmup_batches=4000.0,
+ use_time_embed=False,
+ )
+ return encoder
+
+
+class TtsModel(nn.Module):
+ """The normal TTS model."""
+
+ def __init__(
+ self,
+ fm_decoder: nn.Module,
+ text_encoder: nn.Module,
+ text_embed_dim: int,
+ feat_dim: int,
+ vocab_size: int,
+ pad_id: int = 0,
+ ):
+ """
+ Args:
+            fm_decoder: the flow-matching decoder model; its inputs are the
+              condition embeddings and noisy acoustic features, and its output
+              is the predicted velocity of the acoustic features.
+            text_encoder: the text encoder model; its inputs are text
+              embeddings and its outputs are contextualized text embeddings.
+ text_embed_dim: dimension of text embedding.
+ feat_dim: dimension of acoustic features.
+ vocab_size: vocabulary size.
+ pad_id: padding id.
+ """
+ super().__init__()
+
+ self.feat_dim = feat_dim
+ self.text_embed_dim = text_embed_dim
+ self.pad_id = pad_id
+
+ self.fm_decoder = fm_decoder
+
+ self.text_encoder = text_encoder
+
+ self.embed = nn.Embedding(vocab_size, text_embed_dim)
+
+ self.distill = False
+
+ def forward_fm_decoder(
+ self,
+ t: torch.Tensor,
+ xt: torch.Tensor,
+ text_condition: torch.Tensor,
+ speech_condition: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ guidance_scale: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Compute velocity.
+ Args:
+ t: A tensor of shape (N, 1, 1) or a tensor of a float,
+ in the range of (0, 1).
+ xt: the input of the current timestep, including condition
+ embeddings and noisy acoustic features.
+ text_condition: the text condition embeddings, with the
+ shape (batch, seq_len, emb_dim).
+ speech_condition: the speech condition embeddings, with the
+ shape (batch, seq_len, emb_dim).
+ padding_mask: The mask for padding, True means masked
+ position, with the shape (N, T).
+ guidance_scale: The guidance scale in classifier-free guidance,
+ which is a tensor of shape (N, 1, 1) or a tensor of a float.
+
+ Returns:
+ predicted velocity, with the shape (batch, seq_len, emb_dim).
+ """
+ assert t.dim() in (0, 3)
+ # Handle t with the shape (N, 1, 1):
+        # squeeze the last dimension if its size is 1.
+ while t.dim() > 1 and t.size(-1) == 1:
+ t = t.squeeze(-1)
+ if guidance_scale is not None:
+ while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1:
+ guidance_scale = guidance_scale.squeeze(-1)
+ # Handle t with a single value: expand to the size of batch size.
+ if t.dim() == 0:
+ t = t.repeat(xt.shape[0])
+ if guidance_scale is not None and guidance_scale.dim() == 0:
+ guidance_scale = guidance_scale.repeat(xt.shape[0])
+
+ xt = torch.cat([xt, text_condition, speech_condition], dim=2)
+ vt = self.fm_decoder(
+ x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale
+ )
+ return vt
+
+ def forward_text_embed(
+ self,
+ tokens: List[List[int]],
+ ):
+ """
+ Get the text embeddings.
+ Args:
+ tokens: a list of list of token ids.
+ Returns:
+ embed: the text embeddings, shape (batch, seq_len, emb_dim).
+ tokens_lens: the length of each token sequence, shape (batch,).
+ """
+ device = (
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
+ )
+ tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
+ embed = self.embed(tokens_padded) # (B, S, C)
+ tokens_lens = torch.tensor(
+ [len(token) for token in tokens], dtype=torch.int64, device=device
+ )
+ tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
+
+ embed = self.text_encoder(
+ x=embed, t=None, padding_mask=tokens_padding_mask
+ ) # (B, S, C)
+ return embed, tokens_lens
+
+ def forward_text_condition(
+ self,
+ embed: torch.Tensor,
+ tokens_lens: torch.Tensor,
+ features_lens: torch.Tensor,
+ ):
+ """
+ Get the text condition with the same length of the acoustic feature.
+ Args:
+ embed: the text embeddings, shape (batch, token_seq_len, emb_dim).
+ tokens_lens: the length of each token sequence, shape (batch,).
+ features_lens: the length of each acoustic feature sequence,
+ shape (batch,).
+ Returns:
+ text_condition: the text condition, shape
+ (batch, feature_seq_len, emb_dim).
+ padding_mask: the padding mask of text condition, shape
+ (batch, feature_seq_len).
+ """
+
+ num_frames = int(features_lens.max())
+
+ padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
+
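+        # Each token is given a roughly equal share of the total frames (as the
+        # helper name suggests, an average per-token duration); tokens_index
+        # then maps every output frame back to the token it belongs to.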
+ tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens)
+
+ tokens_index = get_tokens_index(tokens_durations, num_frames).to(
+ embed.device
+ ) # (B, T)
+
+ text_condition = torch.gather(
+ embed,
+ dim=1,
+ index=tokens_index.unsqueeze(-1).expand(
+ embed.size(0), num_frames, embed.size(-1)
+ ),
+ ) # (B, T, F)
+ return text_condition, padding_mask
+
+ def forward_text_train(
+ self,
+ tokens: List[List[int]],
+ features_lens: torch.Tensor,
+ ):
+ """
+ Process text for training, given text tokens and real feature lengths.
+ """
+ embed, tokens_lens = self.forward_text_embed(tokens)
+ text_condition, padding_mask = self.forward_text_condition(
+ embed, tokens_lens, features_lens
+ )
+ return (
+ text_condition,
+ padding_mask,
+ )
+
+ def forward_text_inference_gt_duration(
+ self,
+ tokens: List[List[int]],
+ features_lens: torch.Tensor,
+ prompt_tokens: List[List[int]],
+ prompt_features_lens: torch.Tensor,
+ ):
+ """
+ Process text for inference, given text tokens, real feature lengths and prompts.
+ """
+ tokens = [
+ prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
+ ]
+ features_lens = prompt_features_lens + features_lens
+ embed, tokens_lens = self.forward_text_embed(tokens)
+ text_condition, padding_mask = self.forward_text_condition(
+ embed, tokens_lens, features_lens
+ )
+ return text_condition, padding_mask
+
+ def forward_text_inference_ratio_duration(
+ self,
+ tokens: List[List[int]],
+ prompt_tokens: List[List[int]],
+ prompt_features_lens: torch.Tensor,
+ speed: float,
+ ):
+ """
+ Process text for inference, given text tokens and prompts,
+ feature lengths are predicted with the ratio of token numbers.
+ """
+ device = (
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
+ )
+
+ cat_tokens = [
+ prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
+ ]
+
+ prompt_tokens_lens = torch.tensor(
+ [len(token) for token in prompt_tokens], dtype=torch.int64, device=device
+ )
+
+ cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens)
+
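+        # Predict the total duration from the prompt's frames-per-token rate,
+        # scaled by the concatenated token count and divided by the speed factor.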
+ features_lens = torch.ceil(
+ (prompt_features_lens / prompt_tokens_lens * cat_tokens_lens / speed)
+ ).to(dtype=torch.int64)
+
+ text_condition, padding_mask = self.forward_text_condition(
+ cat_embed, cat_tokens_lens, features_lens
+ )
+ return text_condition, padding_mask
+
+ def forward(
+ self,
+ tokens: List[List[int]],
+ features: torch.Tensor,
+ features_lens: torch.Tensor,
+ noise: torch.Tensor,
+ t: torch.Tensor,
+ condition_drop_ratio: float = 0.0,
+ ) -> torch.Tensor:
+ """Forward pass of the model for training.
+ Args:
+ tokens: a list of list of token ids.
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
+            noise: the initial noise, with the shape (batch, seq_len, feat_dim).
+ t: the time step, with the shape (batch, 1, 1).
+ condition_drop_ratio: the ratio of dropped text condition.
+ Returns:
+ fm_loss: the flow-matching loss.
+ """
+
+ (text_condition, padding_mask,) = self.forward_text_train(
+ tokens=tokens,
+ features_lens=features_lens,
+ )
+
+ speech_condition_mask = condition_time_mask(
+ features_lens=features_lens,
+ mask_percent=(0.7, 1.0),
+ max_len=features.size(1),
+ )
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
+
+ if condition_drop_ratio > 0.0:
+ drop_mask = (
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
+ > condition_drop_ratio
+ )
+ text_condition = text_condition * drop_mask
+
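+        # Linear interpolation between noise (t=0) and the target features
+        # (t=1); the flow-matching training target ut is the constant velocity
+        # d(xt)/dt = features - noise.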
+ xt = features * t + noise * (1 - t)
+ ut = features - noise # (B, T, F)
+
+ vt = self.forward_fm_decoder(
+ t=t,
+ xt=xt,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ )
+
+ loss_mask = speech_condition_mask & (~padding_mask)
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
+
+ return fm_loss
+
+ def sample(
+ self,
+ tokens: List[List[int]],
+ prompt_tokens: List[List[int]],
+ prompt_features: torch.Tensor,
+ prompt_features_lens: torch.Tensor,
+ features_lens: Optional[torch.Tensor] = None,
+ speed: float = 1.0,
+ t_shift: float = 1.0,
+ duration: str = "predict",
+ num_step: int = 5,
+ guidance_scale: float = 0.5,
+ ) -> torch.Tensor:
+ """
+ Generate acoustic features, given text tokens, prompts feature
+ and prompt transcription's text tokens.
+ Args:
+ tokens: a list of list of text tokens.
+ prompt_tokens: a list of list of prompt tokens.
+ prompt_features: the prompt feature with the shape
+ (batch_size, seq_len, feat_dim).
+ prompt_features_lens: the length of each prompt feature,
+ with the shape (batch_size,).
+            features_lens: the length of the predicted feature, with the
+ shape (batch_size,). It is used only when duration is "real".
+ duration: "real" or "predict". If "real", the predicted
+ feature length is given by features_lens.
+ num_step: the number of steps to use in the ODE solver.
+ guidance_scale: the guidance scale for classifier-free guidance.
+ """
+
+ assert duration in ["real", "predict"]
+
+ if duration == "predict":
+ (
+ text_condition,
+ padding_mask,
+ ) = self.forward_text_inference_ratio_duration(
+ tokens=tokens,
+ prompt_tokens=prompt_tokens,
+ prompt_features_lens=prompt_features_lens,
+ speed=speed,
+ )
+ else:
+ assert features_lens is not None
+ text_condition, padding_mask = self.forward_text_inference_gt_duration(
+ tokens=tokens,
+ features_lens=features_lens,
+ prompt_tokens=prompt_tokens,
+ prompt_features_lens=prompt_features_lens,
+ )
+ batch_size, num_frames, _ = text_condition.shape
+
+ speech_condition = torch.nn.functional.pad(
+ prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
+ ) # (B, T, F)
+
+ # False means speech condition positions.
+ speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
+ speech_condition = torch.where(
+ speech_condition_mask.unsqueeze(-1),
+ torch.zeros_like(speech_condition),
+ speech_condition,
+ )
+
+ x0 = torch.randn(
+ batch_size, num_frames, self.feat_dim, device=text_condition.device
+ )
+ solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
+
+ x1 = solver.sample(
+ x=x0,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ t_shift=t_shift,
+ )
+ x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens
+ x1_prompt = torch.zeros(
+ x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device
+ )
+ x1_wo_prompt = torch.zeros(
+ x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device
+ )
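+        # Split each generated sequence into its prompt part and the newly
+        # generated (non-prompt) part, dropping padding frames.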
+ for i in range(x1.size(0)):
+ x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[
+ i,
+ prompt_features_lens[i] : prompt_features_lens[i]
+ + x1_wo_prompt_lens[i],
+ ]
+ x1_prompt[i, : prompt_features_lens[i], :] = x1[
+ i, : prompt_features_lens[i]
+ ]
+
+ return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens
+
+ def sample_intermediate(
+ self,
+ tokens: List[List[int]],
+ features: torch.Tensor,
+ features_lens: torch.Tensor,
+ noise: torch.Tensor,
+ speech_condition_mask: torch.Tensor,
+ t_start: torch.Tensor,
+ t_end: torch.Tensor,
+ num_step: int = 1,
+ guidance_scale: torch.Tensor = None,
+ ) -> torch.Tensor:
+ """
+ Generate acoustic features in intermediate timesteps.
+ Args:
+ tokens: List of list of token ids.
+ features: The acoustic features, with the shape (batch, seq_len, feat_dim).
+ features_lens: The length of each acoustic feature sequence,
+ with the shape (batch,).
+ noise: The initial noise, with the shape (batch, seq_len, feat_dim).
+ speech_condition_mask: The mask for speech condition, True means
+ non-condition positions, with the shape (batch, seq_len).
+ t_start: The start timestep, with the shape (batch, 1, 1).
+ t_end: The end timestep, with the shape (batch, 1, 1).
+ num_step: The number of steps for sampling.
+ guidance_scale: The scale for classifier-free guidance inference,
+ with the shape (batch, 1, 1).
+ """
+ (text_condition, padding_mask,) = self.forward_text_train(
+ tokens=tokens,
+ features_lens=features_lens,
+ )
+
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
+
+ solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
+
+ x_t_end = solver.sample(
+ x=noise,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ t_start=t_start,
+ t_end=t_end,
+ )
+ x_t_end_lens = (~padding_mask).sum(-1)
+ return x_t_end, x_t_end_lens
+
+
+class DistillTTSModelTrainWrapper(TtsModel):
+ """Wrapper for training the distilled TTS model."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.distill = True
+
+ def forward(
+ self,
+ tokens: List[List[int]],
+ features: torch.Tensor,
+ features_lens: torch.Tensor,
+ noise: torch.Tensor,
+ speech_condition_mask: torch.Tensor,
+ t_start: torch.Tensor,
+ t_end: torch.Tensor,
+ num_step: int = 1,
+ guidance_scale: torch.Tensor = None,
+ ) -> torch.Tensor:
+
+ return self.sample_intermediate(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=noise,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t_start,
+ t_end=t_end,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ )
diff --git a/egs/zipvoice/zipvoice/optim.py b/egs/zipvoice/zipvoice/optim.py
new file mode 100644
index 000000000..daf17556a
--- /dev/null
+++ b/egs/zipvoice/zipvoice/optim.py
@@ -0,0 +1,1256 @@
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import logging
+import random
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from lhotse.utils import fix_random_seed
+from torch import Tensor
+from torch.optim import Optimizer
+
+
+class BatchedOptimizer(Optimizer):
+ """
+ This class adds to class Optimizer the capability to optimize parameters in batches:
+ it will stack the parameters and their grads for you so the optimizer can work
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
+ as it reduces the number of kernels launched in the optimizer.
+
+    Args:
+      params: the parameters (or parameter groups) to optimize, as in torch.optim.Optimizer.
+      defaults: a dict of default hyperparameter values, as in torch.optim.Optimizer.
+ """
+
+ def __init__(self, params, defaults):
+ super(BatchedOptimizer, self).__init__(params, defaults)
+
+ @contextlib.contextmanager
+ def batched_params(self, param_group, group_params_names):
+ """
+        This function returns (technically, yields) a list of
+        tuples (p, state, p_names), where
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
+ that share the same shape, and its gradient is also stacked;
+ `state` is the state corresponding to this batch of parameters
+ (it will be physically located in the "state" for one of the real
+        parameters, the first one that has any particular shape and dtype).
+
+ This function is decorated as a context manager so that it can
+ write parameters back to their "real" locations.
+
+ The idea is, instead of doing:
+
+ for p in group["params"]:
+ state = self.state[p]
+ ...
+
+ you can do:
+
+           with self.batched_params(group["params"], group_params_names) as batches:
+ for p, state, p_names in batches:
+ ...
+
+
+ Args:
+ group: a parameter group, which is a list of parameters; should be
+ one of self.param_groups.
+ group_params_names: name for each parameter in group,
+ which is List[str].
+ """
+ batches = defaultdict(
+ list
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+ batches_names = defaultdict(
+ list
+        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
+
+ assert len(param_group) == len(group_params_names)
+ for p, named_p in zip(param_group, group_params_names):
+ key = (str(p.dtype), *p.shape)
+ batches[key].append(p)
+ batches_names[key].append(named_p)
+
+ batches_names_keys = list(batches_names.keys())
+ sorted_idx = sorted(
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
+ )
+ batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
+
+ stacked_params_dict = dict()
+
+ # turn batches into a list, in deterministic order.
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+ # one for each batch in `batches`.
+ tuples = []
+
+ for batch, batch_names in zip(batches, batches_names):
+ p = batch[0]
+ # we arbitrarily store the state in the
+ # state corresponding to the 1st parameter in the
+ # group. class Optimizer will take care of saving/loading state.
+ state = self.state[p]
+ p_stacked = torch.stack(batch)
+ grad = torch.stack(
+ [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
+ )
+ p_stacked.grad = grad
+            # key by (dtype, shape), consistent with how `batches` was keyed above
+            stacked_params_dict[(str(p.dtype), *p.shape)] = p_stacked
+ tuples.append((p_stacked, state, batch_names))
+
+ yield tuples # <-- calling code will do the actual optimization here!
+
+ for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+ for i, p in enumerate(batch): # batch is list of Parameter
+ p.copy_(stacked_params[i])
+
+
+def basic_step(group, p, state, grad):
+ # computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet.
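+    # As a sketch: with g = grad, v = exp_avg_sq (the running average of g**2) and
+    # v_hat = v / (1 - beta2**step) applied only early in training, this returns
+    #     delta = -lr * g / (sqrt(v_hat) + eps),
+    # i.e. Adam's second-moment normalization without the first-moment (momentum) term.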
+ lr = group["lr"]
+ if p.numel() == p.shape[0]:
+ lr = lr * group["scalar_lr_scale"]
+ beta2 = group["betas"][1]
+ eps = group["eps"]
+ # p shape: (batch_size,) or (batch_size, 1, [1,..])
+ try:
+ exp_avg_sq = state[
+ "exp_avg_sq"
+ ] # shape: (batch_size,) or (batch_size, 1, [1,..])
+ except KeyError:
+ exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
+ state["exp_avg_sq"] = exp_avg_sq
+
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+ # bias_correction2 is like in Adam.
+ # slower update at the start will help stability anyway.
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
+ if bias_correction2 < 0.99:
+ # note: not in-place.
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
+ denom = exp_avg_sq.sqrt().add_(eps)
+
+ return -lr * grad / denom
+
+
+def scaling_step(group, p, state, grad):
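+    # Scales the basic Adam step by the parameter's RMS (so the relative change per
+    # step is roughly constant), and every `size_update_period` steps adds a separate
+    # "size" step that learns the overall scale of the parameter tensor.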
+ delta = basic_step(group, p, state, grad)
+ if p.numel() == p.shape[0]:
+ return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.)
+
+ step = state["step"]
+ size_update_period = group["size_update_period"]
+
+ try:
+ param_rms = state["param_rms"]
+ scale_grads = state["scale_grads"]
+ scale_exp_avg_sq = state["scale_exp_avg_sq"]
+ except KeyError:
+ # we know p.ndim > 1 because we'd have returned above if not, so don't worry
+        # about the special case of dim=[] that pytorch treats inconsistently.
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+ param_rms = param_rms.to(torch.float)
+ scale_exp_avg_sq = torch.zeros_like(param_rms)
+ scale_grads = torch.zeros(
+ size_update_period, *param_rms.shape, dtype=torch.float, device=p.device
+ )
+ state["param_rms"] = param_rms
+ state["scale_grads"] = scale_grads
+ state["scale_exp_avg_sq"] = scale_exp_avg_sq
+
+ # on every step, update the gradient w.r.t. the scale of the parameter, we
+ # store these as a batch and periodically update the size (for speed only, to
+ # avoid too many operations).
+ scale_grads[step % size_update_period] = (p * grad).sum(
+ dim=list(range(1, p.ndim)), keepdim=True
+ )
+
+ # periodically recompute the value of param_rms.
+ if step % size_update_period == size_update_period - 1:
+ param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
+
+ param_min_rms = group["param_min_rms"]
+
+ # scale the step size by param_rms. This is the most important "scaling" part of
+ # ScaledAdam
+ delta *= param_rms.clamp(min=param_min_rms)
+
+ if step % size_update_period == size_update_period - 1 and step > 0:
+ # This block updates the size of parameter by adding a step ("delta") value in
+ # the direction of either shrinking or growing it.
+ beta2 = group["betas"][1]
+ size_lr = group["lr"] * group["scalar_lr_scale"]
+ param_max_rms = group["param_max_rms"]
+ eps = group["eps"]
+ batch_size = p.shape[0]
+ # correct beta2 for the size update period: we will have
+ # faster decay at this level.
+ beta2_corr = beta2**size_update_period
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
+ alpha=1 - beta2_corr,
+ ) # shape is (batch_size, 1, 1, ...)
+
+ # The 1st time we reach here is when size_step == 1.
+ size_step = (step + 1) // size_update_period
+ bias_correction2 = 1 - beta2_corr**size_step
+
+ denom = scale_exp_avg_sq.sqrt() + eps
+
+ scale_step = (
+ -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
+ )
+
+ is_too_small = param_rms < param_min_rms
+
+ # when the param gets too small, just don't shrink it any further.
+ scale_step.masked_fill_(is_too_small, 0.0)
+
+ # The following may help prevent instability: don't allow the scale step to be too large in
+ # either direction.
+ scale_step.clamp_(min=-0.1, max=0.1)
+
+ # and ensure the parameter rms after update never exceeds param_max_rms.
+ # We have to look at the trained model for parameters at or around the
+ # param_max_rms, because sometimes they can indicate a problem with the
+ # topology or settings.
+ scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
+
+ delta.add_(p * scale_step)
+
+ return delta
+
+
+def momentum_step(group, p, state, grad):
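+    # Applies an exponential moving average (momentum with coefficient beta1) to the
+    # step produced by scaling_step; the smoothed step is what gets applied to p.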
+ delta = scaling_step(group, p, state, grad)
+ beta1 = group["betas"][0]
+ try:
+ stored_delta = state["delta"]
+ except KeyError:
+ stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
+ state["delta"] = stored_delta
+ stored_delta.mul_(beta1)
+ stored_delta.add_(delta, alpha=(1 - beta1))
+ # we don't bother doing the "bias correction" part of Adam for beta1 because this is just
+ # an edge effect that affects the first 10 or so batches; and the effect of not doing it
+ # is just to do a slower update for the first few batches, which will help stability.
+ return stored_delta
+
+
+class ScaledAdam(BatchedOptimizer):
+ """
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
+ param = underlying_param * log_scale.exp())
+
+
+ Args:
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
+        Unlike common optimizers, which accept model.parameters() or groups of parameters,
+        this optimizer can also accept model.named_parameters() or groups of named parameters.
+ See comments of function _get_names_of_parameters for its 4 possible cases.
+ lr: The learning rate. We will typically use a learning rate schedule that starts
+ at 0.03 and decreases over time, i.e. much higher than other common
+ optimizers.
+ clipping_scale: (e.g. 2.0)
+ A scale for gradient-clipping: if specified, the normalized gradients
+ over the whole model will be clipped to have 2-norm equal to
+ `clipping_scale` times the median 2-norm over the most recent period
+ of `clipping_update_period` minibatches. By "normalized gradients",
+ we mean after multiplying by the rms parameter value for this tensor
+ [for non-scalars]; this is appropriate because our update is scaled
+ by this quantity.
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
+           Must satisfy 0 < beta1 <= beta2 < 1.
+    scalar_lr_scale: A scaling factor on the learning rate that we use to update the scale of each parameter tensor and the scalar parameters of the model.
+ If each parameter were decomposed
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
+           would be the scaling factor on the learning rate of p_scale.
+ eps: A general-purpose epsilon to prevent division by zero
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
+ parameter tensor to be >= this value)
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
+ parameter tensor to be <= this value)
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
+ model has any parameters with numel() == 1).
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
+ of the parameter tensor. This is provided to save a little time
+ in the update.
+    clipping_update_period: if clipping_scale is specified, this is the period, in
+           minibatches, over which we accumulate recent gradient norms and recompute
+           the median norm used for the clipping threshold.
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=3e-02,
+ clipping_scale=None,
+ betas=(0.9, 0.98),
+ scalar_lr_scale=0.1,
+ eps=1.0e-08,
+ param_min_rms=1.0e-05,
+ param_max_rms=3.0,
+ scalar_max=10.0,
+ size_update_period=4,
+ clipping_update_period=100,
+ ):
+
+ defaults = dict(
+ lr=lr,
+ clipping_scale=clipping_scale,
+ betas=betas,
+ scalar_lr_scale=scalar_lr_scale,
+ eps=eps,
+ param_min_rms=param_min_rms,
+ param_max_rms=param_max_rms,
+ scalar_max=scalar_max,
+ size_update_period=size_update_period,
+ clipping_update_period=clipping_update_period,
+ )
+
+        # If params only contains parameters or groups of parameters,
+        # i.e. when parameter names are not given,
+        # this flag will be set to False in function _get_names_of_parameters.
+ self.show_dominant_parameters = True
+ param_groups, parameters_names = self._get_names_of_parameters(params)
+ super(ScaledAdam, self).__init__(param_groups, defaults)
+ assert len(self.param_groups) == len(parameters_names)
+ self.parameters_names = parameters_names
+
+ def _get_names_of_parameters(
+ self, params_or_named_params
+ ) -> Tuple[List[Dict], List[List[str]]]:
+ """
+ Args:
+ params_or_named_params: according to the way ScaledAdam is initialized in train.py,
+            this argument could be one of the following 4 cases:
+ case 1, a generator of parameter, e.g.:
+ optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+ case 2, a list of parameter groups with different config, e.g.:
+ model_param_groups = [
+ {'params': model.encoder.parameters(), 'lr': 0.05},
+ {'params': model.decoder.parameters(), 'lr': 0.01},
+ {'params': model.joiner.parameters(), 'lr': 0.03},
+ ]
+ optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+ case 3, a generator of named_parameter, e.g.:
+ optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+ case 4, a list of named_parameter groups with different config, e.g.:
+ model_named_param_groups = [
+ {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
+ {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
+ {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
+ ]
+ optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+ For case 1 and case 2, input params is used to initialize the underlying torch.optimizer.
+ For case 3 and case 4, firstly, names and params are extracted from input named_params,
+ then, these extracted params are used to initialize the underlying torch.optimizer,
+ and these extracted names are mainly used by function
+ `_show_gradient_dominating_parameter`
+
+ Returns:
+ Returns a tuple containing 2 elements:
+ - `param_groups` with type List[Dict], each Dict element is a parameter group.
+ An example of `param_groups` could be:
+ [
+ {'params': `one iterable of Parameter`, 'lr': 0.05},
+ {'params': `another iterable of Parameter`, 'lr': 0.08},
+ {'params': `a third iterable of Parameter`, 'lr': 0.1},
+ ]
+              - `param_groups_names` with type List[List[str]],
+ each `List[str]` is for a group['params'] in param_groups,
+ and each `str` is the name of a parameter.
+              A dummy name "foo" is assigned to each parameter
+              if the input params are given without names, i.e. case 1 or case 2.
+ """
+ # variable naming convention in this function:
+ # p is short for param.
+ # np is short for named_param.
+ # p_or_np is short for param_or_named_param.
+ # cur is short for current.
+ # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
+ # groups is a List[group]
+
+ iterable_or_groups = list(params_or_named_params)
+ if len(iterable_or_groups) == 0:
+ raise ValueError("optimizer got an empty parameter list")
+
+ # The first value of returned tuple. A list of dicts containing at
+ # least 'params' as a key.
+ param_groups = []
+
+ # The second value of returned tuple,
+ # a List[List[str]], each sub-List is for a group.
+ param_groups_names = []
+
+ if not isinstance(iterable_or_groups[0], dict):
+ # case 1 or case 3,
+ # the input is an iterable of parameter or named parameter.
+ param_iterable_cur_group = []
+ param_names_cur_group = []
+ for p_or_np in iterable_or_groups:
+ if isinstance(p_or_np, tuple):
+ # case 3
+ name, param = p_or_np
+ else:
+ # case 1
+ assert isinstance(p_or_np, torch.Tensor)
+ param = p_or_np
+ # Assign a dummy name as a placeholder
+ name = "foo"
+ self.show_dominant_parameters = False
+ param_iterable_cur_group.append(param)
+ param_names_cur_group.append(name)
+ param_groups.append({"params": param_iterable_cur_group})
+ param_groups_names.append(param_names_cur_group)
+ else:
+ # case 2 or case 4
+ # the input is groups of parameter or named parameter.
+ for cur_group in iterable_or_groups:
+ if "named_params" in cur_group:
+ name_list = [x[0] for x in cur_group["named_params"]]
+ p_list = [x[1] for x in cur_group["named_params"]]
+ del cur_group["named_params"]
+ cur_group["params"] = p_list
+ else:
+ assert "params" in cur_group
+ name_list = ["foo" for _ in cur_group["params"]]
+ param_groups.append(cur_group)
+ param_groups_names.append(name_list)
+
+ return param_groups, param_groups_names
+
+ def __setstate__(self, state):
+ super(ScaledAdam, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ batch = True
+
+ for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
+ with self.batched_params(group["params"], group_params_names) as batches:
+
+                # batches is a list of tuples (stacked_param, state, param_names).
+                # stacked_param is like a regular parameter, and will have a .grad, but the
+                # 1st dim corresponds to a stacking dim; it is not a real dim.
+
+ if (
+ len(batches[0][1]) == 0
+ ): # if len(first state) == 0: not yet initialized
+ clipping_scale = 1
+ else:
+ clipping_scale = self._get_clipping_scale(group, batches)
+
+ for p, state, _ in batches:
+ # Perform optimization step.
+ # grad is not going to be None, we handled that when creating the batches.
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+
+ try:
+ cur_step = state["step"]
+ except KeyError:
+ state["step"] = 0
+ cur_step = 0
+
+ grad = (
+ p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)
+ )
+ p += momentum_step(group, p.detach(), state, grad)
+
+ if p.numel() == p.shape[0]: # scalar parameter
+ scalar_max = group["scalar_max"]
+ p.clamp_(min=-scalar_max, max=scalar_max)
+
+ state["step"] = cur_step + 1
+
+ return loss
+
+ def _get_clipping_scale(
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
+ ) -> float:
+ """
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
+ by this amount before applying the rest of the update.
+
+ Args:
+ group: the parameter group, an item in self.param_groups
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+              param_names is a List[str], where each str is the name of a parameter
+                in the batched set of parameters "param".
+ """
+ assert len(tuples) >= 1
+ clipping_scale = group["clipping_scale"]
+ (first_p, first_state, _) = tuples[0]
+ step = first_state["step"]
+ if clipping_scale is None or step == 0:
+ # no clipping. return early on step == 0 because the other
+ # parameters' state won't have been initialized yet.
+ return 1.0
+ clipping_update_period = group["clipping_update_period"]
+ scalar_lr_scale = group["scalar_lr_scale"]
+
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
+ for (p, state, param_names) in tuples:
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+ if p.numel() == p.shape[0]: # a batch of scalars
+ tot_sumsq += (grad**2).sum() * (
+ scalar_lr_scale**2
+ ) # sum() to change shape [1] to []
+ else:
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
+
+ tot_norm = tot_sumsq.sqrt()
+ if "model_norms" not in first_state:
+ first_state["model_norms"] = torch.zeros(
+ clipping_update_period, device=p.device
+ )
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
+
+ irregular_estimate_steps = [
+ i for i in [10, 20, 40] if i < clipping_update_period
+ ]
+ if step % clipping_update_period == 0 or step in irregular_estimate_steps:
+ # Print some stats.
+ # We don't reach here if step == 0 because we would have returned
+ # above.
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
+ if step in irregular_estimate_steps:
+ sorted_norms = sorted_norms[-step:]
+ num_norms = sorted_norms.numel()
+ quartiles = []
+ for n in range(0, 5):
+ index = min(num_norms - 1, (num_norms // 4) * n)
+ quartiles.append(sorted_norms[index].item())
+
+ median = quartiles[2]
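+            # `median - median != 0` is true only if the median is NaN or infinite,
+            # i.e. too many of the recent gradient norms were not finite.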
+ if median - median != 0:
+ raise RuntimeError("Too many grads were not finite")
+ threshold = clipping_scale * median
+ if step in irregular_estimate_steps:
+ # use larger thresholds on first few steps of estimating threshold,
+ # as norm may be changing rapidly.
+ threshold = threshold * 2.0
+ first_state["model_norm_threshold"] = threshold
+ percent_clipped = (
+ first_state["num_clipped"] * 100.0 / num_norms
+ if "num_clipped" in first_state
+ else 0.0
+ )
+ first_state["num_clipped"] = 0
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
+ logging.warning(
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
+ )
+
+ try:
+ model_norm_threshold = first_state["model_norm_threshold"]
+ except KeyError:
+ return 1.0 # threshold has not yet been set.
+
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+ if ans != ans: # e.g. ans is nan
+ ans = 0.0
+ if ans < 1.0:
+ first_state["num_clipped"] += 1
+ if ans < 0.5:
+ logging.warning(
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
+ )
+ if self.show_dominant_parameters:
+ assert p.shape[0] == len(param_names)
+ self._show_gradient_dominating_parameter(
+ tuples, tot_sumsq, group["scalar_lr_scale"]
+ )
+ self._show_param_with_unusual_grad(tuples)
+
+ if ans == 0.0:
+ for (p, state, param_names) in tuples:
+                p.grad.zero_()  # get rid of infinities / NaNs
+
+ return ans
+
+ def _show_param_with_unusual_grad(
+ self,
+ tuples: List[Tuple[Tensor, dict, List[str]]],
+ ):
+ """
+        Print information about the parameters with the largest ratio of grad-on-this-batch
+        to the typical (running-average) grad size.
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+              param_names is a List[str], where each str is the name of a parameter
+              in the batched set of parameters "param".
+ """
+ largest_ratio = 0.0
+ largest_name = ""
+ # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
+ ratios_names = []
+ for (p, state, batch_param_names) in tuples:
+ dims = list(range(1, p.ndim))
+
+ def mean(x):
+ # workaround for bad interface of torch's "mean" for when dims is the empty list.
+ if len(dims) > 0:
+ return x.mean(dim=dims)
+ else:
+ return x
+
+ grad_ratio = (
+ (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
+ .sqrt()
+ .to("cpu")
+ )
+
+ ratios_names += zip(
+ grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
+ )
+
+ ratios_names = sorted(ratios_names, reverse=True)
+ ratios_names = ratios_names[:10]
+ ratios_names = [
+ (ratio, name, largest_index(tensor))
+ for (ratio, name, tensor) in ratios_names
+ ]
+
+ logging.debug(
+ f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}"
+ )
+
+ def _show_gradient_dominating_parameter(
+ self,
+ tuples: List[Tuple[Tensor, dict, List[str]]],
+ tot_sumsq: Tensor,
+ scalar_lr_scale: float,
+ ):
+ """
+        Show information about the parameter which dominates tot_sumsq.
+
+ Args:
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+              param_names is a List[str], where each str is the name of a parameter
+              in the batched set of parameters "param".
+            tot_sumsq: sumsq of all parameters. Though it could be calculated
+                from tuples, we pass it in to save some time.
+ """
+ all_sumsq_orig = {}
+ for (p, state, batch_param_names) in tuples:
+            # p is a stacked batch of parameters.
+ batch_grad = p.grad
+ if p.numel() == p.shape[0]: # a batch of scalars
+ # Dummy values used by following `zip` statement.
+ batch_rms_orig = torch.full(
+ p.shape, scalar_lr_scale, device=batch_grad.device
+ )
+ else:
+ batch_rms_orig = state["param_rms"]
+ batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
+ if batch_grad.ndim > 1:
+ # need to guard it with if-statement because sum() sums over
+ # all dims if dim == ().
+ batch_sumsq_orig = batch_sumsq_orig.sum(
+ dim=list(range(1, batch_grad.ndim))
+ )
+ for name, sumsq_orig, rms, grad in zip(
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
+ ):
+
+ proportion_orig = sumsq_orig / tot_sumsq
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
+
+ sorted_by_proportion = {
+ k: v
+ for k, v in sorted(
+ all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
+ )
+ }
+ dominant_param_name = next(iter(sorted_by_proportion))
+ (
+ dominant_proportion,
+ dominant_sumsq,
+ dominant_rms,
+ dominant_grad,
+ ) = sorted_by_proportion[dominant_param_name]
+ logging.debug(
+ f"Parameter dominating tot_sumsq {dominant_param_name}"
+ f" with proportion {dominant_proportion:.2f},"
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+ f"={dominant_sumsq:.3e},"
+ f" grad_sumsq={(dominant_grad**2).sum():.3e},"
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+ )
+
+
+def largest_index(x: Tensor):
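+    """
+    Returns the index of the element of x with the largest absolute value,
+    as a list with one entry per dimension (i.e. a multi-dimensional index).
+    """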
+ x = x.contiguous()
+ argmax = x.abs().argmax().item()
+ return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]
+
+
+class LRScheduler(object):
+ """
+ Base-class for learning rate schedulers where the learning-rate depends on both the
+ batch and the epoch.
+ """
+
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
+ # Attach optimizer
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
+ self.optimizer = optimizer
+ self.verbose = verbose
+
+ for group in optimizer.param_groups:
+ group.setdefault("base_lr", group["lr"])
+
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
+
+ self.epoch = 0
+ self.batch = 0
+
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ """
+ return {
+ # the user might try to override the base_lr, so don't include this in the state.
+ # previously they were included.
+ # "base_lrs": self.base_lrs,
+ "epoch": self.epoch,
+ "batch": self.batch,
+ }
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Args:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # the things with base_lrs are a work-around for a previous problem
+ # where base_lrs were written with the state dict.
+ base_lrs = self.base_lrs
+ self.__dict__.update(state_dict)
+ self.base_lrs = base_lrs
+
+ def get_last_lr(self) -> List[float]:
+ """Return last computed learning rate by current scheduler. Will be a list of float."""
+ return self._last_lr
+
+ def get_lr(self):
+ # Compute list of learning rates from self.epoch and self.batch and
+ # self.base_lrs; this must be overloaded by the user.
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
+ raise NotImplementedError
+
+ def step_batch(self, batch: Optional[int] = None) -> None:
+ # Step the batch index, or just set it. If `batch` is specified, it
+ # must be the batch index from the start of training, i.e. summed over
+ # all epochs.
+ # You can call this in any order; if you don't provide 'batch', it should
+ # of course be called once per batch.
+ if batch is not None:
+ self.batch = batch
+ else:
+ self.batch = self.batch + 1
+ self._set_lrs()
+
+ def step_epoch(self, epoch: Optional[int] = None):
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg,
+ # you should call this at the start of the epoch; if you don't provide the 'epoch'
+ # arg, you should call it at the end of the epoch.
+ if epoch is not None:
+ self.epoch = epoch
+ else:
+ self.epoch = self.epoch + 1
+ self._set_lrs()
+
+ def _set_lrs(self):
+ values = self.get_lr()
+ assert len(values) == len(self.optimizer.param_groups)
+
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+ param_group, lr = data
+ param_group["lr"] = lr
+ self.print_lr(self.verbose, i, lr)
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+
+ def print_lr(self, is_verbose, group, lr):
+ """Display the current learning rate."""
+ if is_verbose:
+ logging.warning(
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
+ f" of group {group} to {lr:.4e}."
+ )
+
+
+class Eden(LRScheduler):
+ """
+ Eden scheduler.
+ The basic formula (before warmup) is:
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
+    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
+ and then stays constant at 1.
+
+ If you don't have the concept of epochs, or one epoch takes a very long time,
+ you can replace the notion of 'epoch' with some measure of the amount of data
+ processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
+ some measure representing "quite a lot of data": say, one fifth or one third
+ of an entire training run, but it doesn't matter much. You could also use
+ Eden2 which has only the notion of batches.
+
+ We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+
+ Args:
+ optimizer: the optimizer to change the learning rates on
+ lr_batches: the number of batches after which we start significantly
+ decreasing the learning rate, suggest 5000.
+ lr_epochs: the number of epochs after which we start significantly
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
+ 20 to 40 epochs, but may need smaller number if dataset is huge
+ and you will do few epochs.
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ lr_batches: Union[int, float],
+ lr_epochs: Union[int, float],
+ warmup_batches: Union[int, float] = 500.0,
+ warmup_start: float = 0.5,
+ verbose: bool = False,
+ ):
+ super(Eden, self).__init__(optimizer, verbose)
+ self.lr_batches = lr_batches
+ self.lr_epochs = lr_epochs
+ self.warmup_batches = warmup_batches
+
+ assert 0.0 <= warmup_start <= 1.0, warmup_start
+ self.warmup_start = warmup_start
+
+ def get_lr(self):
+ factor = (
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
+ ) ** -0.25 * (
+ ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
+ )
+ warmup_factor = (
+ 1.0
+ if self.batch >= self.warmup_batches
+ else self.warmup_start
+ + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+ # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
+ )
+
+ return [x * factor * warmup_factor for x in self.base_lrs]
+
+
+class Eden2(LRScheduler):
+ """
+ Eden2 scheduler, simpler than Eden because it does not use the notion of epoch,
+ only batches.
+
+ The basic formula (before warmup) is:
+       lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup
+
+    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
+ and then stays constant at 1.
+
+
+    We suggest base_lr = 0.04 (passed to the optimizer) if used with ScaledAdam.
+
+ Args:
+ optimizer: the optimizer to change the learning rates on
+ lr_batches: the number of batches after which we start significantly
+ decreasing the learning rate, suggest 5000.
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ lr_batches: Union[int, float],
+ warmup_batches: Union[int, float] = 500.0,
+ warmup_start: float = 0.5,
+ verbose: bool = False,
+ ):
+ super().__init__(optimizer, verbose)
+ self.lr_batches = lr_batches
+ self.warmup_batches = warmup_batches
+
+ assert 0.0 <= warmup_start <= 1.0, warmup_start
+ self.warmup_start = warmup_start
+
+ def get_lr(self):
+ factor = (
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
+ ) ** -0.5
+ warmup_factor = (
+ 1.0
+ if self.batch >= self.warmup_batches
+ else self.warmup_start
+ + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+ # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
+ )
+
+ return [x * factor * warmup_factor for x in self.base_lrs]
+
+
+class FixedLRScheduler(LRScheduler):
+ """
+ Fixed learning rate scheduler.
+
+ Args:
+ optimizer: the optimizer to change the learning rates on
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ verbose: bool = False,
+ ):
+ super(FixedLRScheduler, self).__init__(optimizer, verbose)
+
+ def get_lr(self):
+
+ return [x for x in self.base_lrs]
+
+
+def _test_eden():
+ m = torch.nn.Linear(100, 100)
+ optim = ScaledAdam(m.parameters(), lr=0.03)
+
+ scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
+
+ for epoch in range(10):
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
+
+ for step in range(20):
+ x = torch.randn(200, 100).detach()
+ x.requires_grad = True
+ y = m(x)
+ dy = torch.randn(200, 100).detach()
+ f = (y * dy).sum()
+ f.backward()
+
+ optim.step()
+ scheduler.step_batch()
+ optim.zero_grad()
+
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
+ logging.info(f"state dict = {scheduler.state_dict()}")
+
+
+# This is included mostly as a baseline for ScaledAdam.
+class Eve(Optimizer):
+ """
+ Implements Eve algorithm. This is a modified version of AdamW with a special
+ way of setting the weight-decay / shrinkage-factor, which is designed to make the
+ rms of the parameters approach a particular target_rms (default: 0.1). This is
+ for use with networks with 'scaled' versions of modules (see scaling.py), which
+ will be close to invariant to the absolute scale on the parameter matrix.
+
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+ Eve is unpublished so far.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay coefficient (default: 1e-3;
+            this value means that the weight would decay significantly after
+            about 1k minibatches). It is not multiplied by the learning rate, but
+            is applied only when the RMS-value of the parameter is > target_rms.
+ target_rms (float, optional): target root-mean-square value of
+ parameters, if they fall below this we will stop applying weight decay.
+
+
+ .. _Adam: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.98),
+ eps=1e-8,
+ weight_decay=1e-3,
+ target_rms=0.1,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0 <= weight_decay <= 0.1:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0 < target_rms <= 10.0:
+ raise ValueError("Invalid target_rms value: {}".format(target_rms))
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ target_rms=target_rms,
+ )
+ super(Eve, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(Eve, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+
+ # Perform optimization step
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError("AdamW does not support sparse gradients")
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(
+ p, memory_format=torch.preserve_format
+ )
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(
+ p, memory_format=torch.preserve_format
+ )
+
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+
+ beta1, beta2 = group["betas"]
+
+ state["step"] += 1
+ bias_correction1 = 1 - beta1 ** state["step"]
+ bias_correction2 = 1 - beta2 ** state["step"]
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_(
+ group["eps"]
+ )
+
+ step_size = group["lr"] / bias_correction1
+ target_rms = group["target_rms"]
+ weight_decay = group["weight_decay"]
+
+ if p.numel() > 1:
+ # avoid applying this weight-decay on "scaling factors"
+ # (which are scalar).
+ is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
+ p.mul_(1 - (weight_decay * is_above_target_rms))
+
+ p.addcdiv_(exp_avg, denom, value=-step_size)
+
+ if random.random() < 0.0005:
+ step = (exp_avg / denom) * step_size
+ logging.info(
+ f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
+ )
+
+ return loss
+
+
+def _test_scaled_adam(hidden_dim: int):
+ import timeit
+
+ from scaling import ScaledLinear
+
+ E = 100
+ B = 4
+ T = 2
+ logging.info("in test_eve_cain")
+ # device = torch.device('cuda')
+ device = torch.device("cpu")
+ dtype = torch.float32
+
+ fix_random_seed(42)
+ # these input_magnitudes and output_magnitudes are to test that
+    # the optimizer is working as we expect and is able to adjust scales of
+ # different dims differently.
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+ output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+
+ for iter in [1, 0]:
+ fix_random_seed(42)
+ Linear = torch.nn.Linear if iter == 0 else ScaledLinear
+
+ m = torch.nn.Sequential(
+ Linear(E, hidden_dim),
+ torch.nn.PReLU(),
+ Linear(hidden_dim, hidden_dim),
+ torch.nn.PReLU(),
+ Linear(hidden_dim, E),
+ ).to(device)
+
+ train_pairs = [
+ (
+ 100.0
+ * torch.randn(B, T, E, device=device, dtype=dtype)
+ * input_magnitudes,
+ torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
+ )
+ for _ in range(20)
+ ]
+
+ if iter == 0:
+ optim = Eve(m.parameters(), lr=0.003)
+ elif iter == 1:
+ optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
+
+ start = timeit.default_timer()
+ avg_loss = 0.0
+ for epoch in range(180):
+ scheduler.step_epoch()
+ # if epoch == 100 and iter in [2,3]:
+ # optim.reset_speedup() # check it doesn't crash.
+
+ # if epoch == 130:
+ # opts = diagnostics.TensorDiagnosticOptions(
+ # 512
+ # ) # allow 4 megabytes per sub-module
+ # diagnostic = diagnostics.attach_diagnostics(m, opts)
+
+ for n, (x, y) in enumerate(train_pairs):
+ y_out = m(x)
+ loss = ((y_out - y) ** 2).mean() * 100.0
+ if epoch == 0 and n == 0:
+ avg_loss = loss.item()
+ else:
+ avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
+ if n == 0 and epoch % 5 == 0:
+ # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
+ # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
+ # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
+ # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
+ # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
+ # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
+ # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
+ # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
+ lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
+ ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
+ loss.log().backward()
+ optim.step()
+ optim.zero_grad()
+ scheduler.step_batch()
+
+ # diagnostic.print_diagnostics()
+
+ stop = timeit.default_timer()
+ logging.info(f"Iter={iter}, Time taken: {stop - start}")
+
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
+ # logging.info("state dict = ", scheduler.state_dict())
+ # logging.info("optim state_dict = ", optim.state_dict())
+ logging.info(f"input_magnitudes = {input_magnitudes}")
+ logging.info(f"output_magnitudes = {output_magnitudes}")
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ logging.getLogger().setLevel(logging.INFO)
+ import subprocess
+
+ s = subprocess.check_output(
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
+ )
+ logging.info(s)
+ import sys
+
+ if len(sys.argv) > 1:
+ hidden_dim = int(sys.argv[1])
+ else:
+ hidden_dim = 200
+
+ _test_scaled_adam(hidden_dim)
+ _test_eden()
diff --git a/egs/zipvoice/zipvoice/scaling.py b/egs/zipvoice/zipvoice/scaling.py
new file mode 100644
index 000000000..5211e3a76
--- /dev/null
+++ b/egs/zipvoice/zipvoice/scaling.py
@@ -0,0 +1,1910 @@
+# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import math
+import random
+from typing import Optional, Tuple, Union
+
+import k2
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+custom_bwd = lambda func: torch.amp.custom_bwd(func, device_type="cuda")
+custom_fwd = lambda func: torch.amp.custom_fwd(func, device_type="cuda")
+
+
+def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
+ max_value = torch.max(x, y)
+ diff = torch.abs(x - y)
+ return max_value + torch.log1p(torch.exp(-diff))
+
+
+# RuntimeError: Exporting the operator logaddexp to ONNX opset version
+# 14 is not supported. Please feel free to request support or submit
+# a pull request on PyTorch GitHub.
+#
+# The following function is to solve the above error when exporting
+# models to ONNX via torch.jit.trace()
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+ # Caution(fangjun): Put torch.jit.is_scripting() before
+ # torch.onnx.is_in_onnx_export();
+ # otherwise, it will cause errors for torch.jit.script().
+ #
+ # torch.logaddexp() works for both torch.jit.script() and
+ # torch.jit.trace() but it causes errors for ONNX export.
+ #
+ if torch.jit.is_scripting():
+ # Note: We cannot use torch.jit.is_tracing() here as it also
+ # matches torch.onnx.export().
+ return torch.logaddexp(x, y)
+ elif torch.onnx.is_in_onnx_export():
+ return logaddexp_onnx(x, y)
+ else:
+ # for torch.jit.trace()
+ return torch.logaddexp(x, y)
+
+
+class PiecewiseLinear(object):
+ """
+ Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
+    the x values in order. x values <[initial x] or >[final x] are mapped to [initial y], [final y]
+ respectively.
+ """
+
+ def __init__(self, *args):
+ assert len(args) >= 1, len(args)
+ if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
+ self.pairs = list(args[0].pairs)
+ else:
+ self.pairs = [(float(x), float(y)) for x, y in args]
+ for x, y in self.pairs:
+ assert isinstance(x, (float, int)), type(x)
+ assert isinstance(y, (float, int)), type(y)
+
+ for i in range(len(self.pairs) - 1):
+ assert self.pairs[i + 1][0] > self.pairs[i][0], (
+ i,
+ self.pairs[i],
+ self.pairs[i + 1],
+ )
+
+ def __str__(self):
+ # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
+ return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
+
+ def __call__(self, x):
+ if x <= self.pairs[0][0]:
+ return self.pairs[0][1]
+ elif x >= self.pairs[-1][0]:
+ return self.pairs[-1][1]
+ else:
+ cur_x, cur_y = self.pairs[0]
+ for i in range(1, len(self.pairs)):
+ next_x, next_y = self.pairs[i]
+ if x >= cur_x and x <= next_x:
+ return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
+ cur_x, cur_y = next_x, next_y
+ assert False
+
+ def __mul__(self, alpha):
+ return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
+
+ def __add__(self, x):
+ if isinstance(x, (float, int)):
+ return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
+ s, x = self.get_common_basis(x)
+ return PiecewiseLinear(
+ *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def max(self, x):
+ if isinstance(x, (float, int)):
+ x = PiecewiseLinear((0, x))
+ s, x = self.get_common_basis(x, include_crossings=True)
+ return PiecewiseLinear(
+ *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def min(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ x = PiecewiseLinear((0, x))
+ s, x = self.get_common_basis(x, include_crossings=True)
+ return PiecewiseLinear(
+ *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
+ )
+
+ def __eq__(self, other):
+ return self.pairs == other.pairs
+
+ def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
+ """
+ Returns (self_mod, p_mod) which are equivalent piecewise linear
+ functions to self and p, but with the same x values.
+
+ p: the other piecewise linear function
+ include_crossings: if true, include in the x values positions
+          where the functions indicated by this and p cross.
+ """
+ assert isinstance(p, PiecewiseLinear), type(p)
+
+ # get sorted x-values without repetition.
+ x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
+ y_vals1 = [self(x) for x in x_vals]
+ y_vals2 = [p(x) for x in x_vals]
+
+ if include_crossings:
+ extra_x_vals = []
+ for i in range(len(x_vals) - 1):
+ if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
+ # if the two lines in this subsegment potentially cross each other..
+ diff_cur = abs(y_vals1[i] - y_vals2[i])
+ diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
+ # `pos`, between 0 and 1, gives the relative x position,
+ # with 0 being x_vals[i] and 1 being x_vals[i+1].
+ pos = diff_cur / (diff_cur + diff_next)
+ extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
+ extra_x_vals.append(extra_x_val)
+ if len(extra_x_vals) > 0:
+ x_vals = sorted(set(x_vals + extra_x_vals))
+ y_vals1 = [self(x) for x in x_vals]
+ y_vals2 = [p(x) for x in x_vals]
+ return (
+ PiecewiseLinear(*zip(x_vals, y_vals1)),
+ PiecewiseLinear(*zip(x_vals, y_vals2)),
+ )
+
+
+class ScheduledFloat(torch.nn.Module):
+ """
+ This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
+ it does not have a working forward() function. You are supposed to cast it to float, as
+ in, float(parent_module.whatever), and use it as something like a dropout prob.
+
+ It is a floating point value whose value changes depending on the batch count of the
+ training loop. It is a piecewise linear function where you specify the (x,y) pairs
+ in sorted order on x; x corresponds to the batch index. For batch-index values before the
+ first x or after the last x, we just use the first or last y value.
+
+ Example:
+ self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+ `default` is used when self.batch_count is not set or not in training mode or in
+ torch.jit scripting mode.
+ """
+
+ def __init__(self, *args, default: float = 0.0):
+ super().__init__()
+ # self.batch_count and self.name will be written to in the training loop.
+ self.batch_count = None
+ self.name = None
+ self.default = default
+ self.schedule = PiecewiseLinear(*args)
+
+ def extra_repr(self) -> str:
+ return (
+ f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
+ )
+
+ def __float__(self):
+ batch_count = self.batch_count
+ if (
+ batch_count is None
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return float(self.default)
+ else:
+ ans = self.schedule(self.batch_count)
+ if random.random() < 0.0002:
+ logging.debug(
+ f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}"
+ )
+ return ans
+
+ def __add__(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ return ScheduledFloat(self.schedule + x, default=self.default)
+ else:
+ return ScheduledFloat(
+ self.schedule + x.schedule, default=self.default + x.default
+ )
+
+ def max(self, x):
+ if isinstance(x, float) or isinstance(x, int):
+ return ScheduledFloat(self.schedule.max(x), default=self.default)
+ else:
+ return ScheduledFloat(
+ self.schedule.max(x.schedule), default=max(self.default, x.default)
+ )
+
+
+FloatLike = Union[float, ScheduledFloat]
+
+
+def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
+ """
+ A randomized way of casting a floating point value to half precision.
+ """
+ if x.dtype == torch.float16:
+ return x
+ x_abs = x.abs()
+ is_too_small = x_abs < min_abs
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
+ # for those elements].
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
+
+
+class CutoffEstimator:
+ """
+ Estimates cutoffs of an arbitrary numerical quantity such that a specified
+ proportion of items will be above the cutoff on average.
+
+ p is the proportion of items that should be above the cutoff.
+ """
+
+ def __init__(self, p: float):
+ self.p = p
+ # total count of items
+ self.count = 0
+ # total count of items that were above the cutoff
+ self.count_above = 0
+ # initial cutoff value
+ self.cutoff = 0
+
+ def __call__(self, x: float) -> bool:
+ """
+ Returns true if x is above the cutoff.
+ """
+ ans = x > self.cutoff
+ self.count += 1
+ if ans:
+ self.count_above += 1
+ cur_p = self.count_above / self.count
+ delta_p = cur_p - self.p
+ if (delta_p > 0) == ans:
+ q = abs(delta_p)
+ self.cutoff = x * q + self.cutoff * (1 - q)
+ return ans
+
+
+class SoftmaxFunction(torch.autograd.Function):
+ """
+ Tries to handle half-precision derivatives in a randomized way that should
+ be more accurate for training than the default behavior.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor, dim: int):
+ ans = x.softmax(dim=dim)
+ # if x dtype is float16, x.softmax() returns a float32 because
+ # (presumably) that op does not support float16, and autocast
+ # is enabled.
+ if torch.is_autocast_enabled():
+ ans = ans.to(torch.float16)
+ ctx.save_for_backward(ans)
+ ctx.x_dtype = x.dtype
+ ctx.dim = dim
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ (ans,) = ctx.saved_tensors
+ with torch.amp.autocast("cuda", enabled=False):
+ ans_grad = ans_grad.to(torch.float32)
+ ans = ans.to(torch.float32)
+ x_grad = ans_grad * ans
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+ return x_grad, None
+
+
+def softmax(x: Tensor, dim: int):
+ if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x.softmax(dim=dim)
+
+ return SoftmaxFunction.apply(x, dim)
+
+
+class MaxEigLimiterFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ coeffs: Tensor,
+ direction: Tensor,
+ channel_dim: int,
+ grad_scale: float,
+ ) -> Tensor:
+ ctx.channel_dim = channel_dim
+ ctx.grad_scale = grad_scale
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad, *args):
+ with torch.enable_grad():
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
+ x_orig.requires_grad = True
+ num_channels = x_orig.shape[ctx.channel_dim]
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
+ new_direction.requires_grad = False
+ x = x - x.mean(dim=0)
+ x_var = (x**2).mean()
+ x_residual = x - coeffs * new_direction
+ x_residual_var = (x_residual**2).mean()
+ # `variance_proportion` is the proportion of the variance accounted for
+ # by the top eigen-direction. This is to be minimized.
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+ variance_proportion.backward()
+ x_orig_grad = x_orig.grad
+ x_extra_grad = (
+ x_orig.grad
+ * ctx.grad_scale
+ * x_grad.norm()
+ / (x_orig_grad.norm() + 1.0e-20)
+ )
+ return x_grad + x_extra_grad.detach(), None, None, None, None
+
+
+class BiasNormFunction(torch.autograd.Function):
+ # This computes:
+ # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
+ # return x * scales
+ # (after unsqueezing the bias), but it does it in a memory-efficient way so that
+ # it can just store the returned value (chances are, this will also be needed for
+ # some other reason, related to the next operation, so we can save memory).
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ bias: Tensor,
+ log_scale: Tensor,
+ channel_dim: int,
+ store_output_for_backprop: bool,
+ ) -> Tensor:
+ assert bias.ndim == 1
+ if channel_dim < 0:
+ channel_dim = channel_dim + x.ndim
+ ctx.store_output_for_backprop = store_output_for_backprop
+ ctx.channel_dim = channel_dim
+ for _ in range(channel_dim + 1, x.ndim):
+ bias = bias.unsqueeze(-1)
+ scales = (
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
+ ) * log_scale.exp()
+ ans = x * scales
+ ctx.save_for_backward(
+ ans.detach() if store_output_for_backprop else x,
+ scales.detach(),
+ bias.detach(),
+ log_scale.detach(),
+ )
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor) -> Tensor:
+ ans_or_x, scales, bias, log_scale = ctx.saved_tensors
+ if ctx.store_output_for_backprop:
+ x = ans_or_x / scales
+ else:
+ x = ans_or_x
+ x = x.detach()
+ x.requires_grad = True
+ bias.requires_grad = True
+ log_scale.requires_grad = True
+ with torch.enable_grad():
+ # recompute scales from x, bias and log_scale.
+ scales = (
+ torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
+ ) * log_scale.exp()
+ ans = x * scales
+ ans.backward(gradient=ans_grad)
+ return x.grad, bias.grad.flatten(), log_scale.grad, None, None
+
+
+class BiasNorm(torch.nn.Module):
+ """
+ This is intended to be a simpler, and hopefully cheaper, replacement for
+ LayerNorm. The observation this is based on, is that Transformer-type
+ networks, especially with pre-norm, sometimes seem to set one of the
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
+ the LayerNorm because the output magnitude is then not strongly dependent
+ on the other (useful) features. Presumably the weight and bias of the
+ LayerNorm are required to allow it to do this.
+
+ Instead, we give the BiasNorm a trainable bias that it can use when
+ computing the scale for normalization. We also give it a (scalar)
+ trainable scale on the output.
+
+
+ Args:
+ num_channels: the number of channels, e.g. 512.
+ channel_dim: the axis/dimension corresponding to the channel,
+ interpreted as an offset from the input's ndim if negative.
+ This is NOT the num_channels; it should typically be one of
+ {-2, -1, 0, 1, 2, 3}.
+ log_scale: the initial log-scale that we multiply the output by; this
+ is learnable.
+ log_scale_min: FloatLike, minimum allowed value of log_scale
+ log_scale_max: FloatLike, maximum allowed value of log_scale
+ store_output_for_backprop: only possibly affects memory use; recommend
+ to set to True if you think the output of this module is more likely
+ than the input of this module to be required to be stored for the
+ backprop.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int = -1, # CAUTION: see documentation.
+ log_scale: float = 1.0,
+ log_scale_min: float = -1.5,
+ log_scale_max: float = 1.5,
+ store_output_for_backprop: bool = False,
+ ) -> None:
+ super(BiasNorm, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.log_scale = nn.Parameter(torch.tensor(log_scale))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+
+ self.log_scale_min = log_scale_min
+ self.log_scale_max = log_scale_max
+
+ self.store_output_for_backprop = store_output_for_backprop
+
+ def forward(self, x: Tensor) -> Tensor:
+ assert x.shape[self.channel_dim] == self.num_channels
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ channel_dim = self.channel_dim
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ bias = self.bias
+ for _ in range(channel_dim + 1, x.ndim):
+ bias = bias.unsqueeze(-1)
+ scales = (
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
+ ) * self.log_scale.exp()
+ return x * scales
+
+ log_scale = limit_param_value(
+ self.log_scale,
+ min=float(self.log_scale_min),
+ max=float(self.log_scale_max),
+ training=self.training,
+ )
+
+ return BiasNormFunction.apply(
+ x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop
+ )
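+
+# A minimal usage sketch for BiasNorm (illustrative only; the shapes and sizes
+# below are arbitrary examples, not values taken from this recipe):
+#
+#   norm = BiasNorm(num_channels=512, channel_dim=-1)
+#   x = torch.randn(2, 100, 512)   # (batch, seq_len, channels)
+#   y = norm(x)                    # same shape, scaled by the per-frame RMS of (x - bias)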
+
+
+def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
+ """
+ Behaves like a constructor of a modified version of nn.Linear
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Linear accepts
+ e.g. in_features, out_features, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+           (it scales the initial weight values and the bias initialization range).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Linear(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
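+
+# Illustrative example (hypothetical sizes): a linear layer whose initial output
+# magnitude is roughly a quarter of the PyTorch default:
+#   proj = ScaledLinear(256, 256, bias=True, initial_scale=0.25)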
+
+
+def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
+ """
+ Behaves like a constructor of a modified version of nn.Conv1d
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+        Accepts the standard args and kwargs that nn.Conv1d accepts
+        e.g. in_channels, out_channels, kernel_size, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+           (it scales the initial weight values and the bias initialization range).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Conv1d(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
+def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
+ """
+ Behaves like a constructor of a modified version of nn.Conv2d
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+        Accepts the standard args and kwargs that nn.Conv2d accepts
+        e.g. in_channels, out_channels, kernel_size, bias=False, but:
+ NO PADDING-RELATED ARGS.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+           (it scales the initial weight values and the bias initialization range).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Conv2d(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
+ return ans
+
+
+class ChunkCausalDepthwiseConv1d(torch.nn.Module):
+ """
+ Behaves like a depthwise 1d convolution, except that it is causal in
+ a chunkwise way, as if we had a block-triangular attention mask.
+ The chunk size is provided at test time (it should probably be
+ kept in sync with the attention mask).
+
+ This has a little more than twice the parameters of a conventional
+ depthwise conv1d module: we implement it by having one
+ depthwise convolution, of half the width, that is causal (via
+ right-padding); and one depthwise convolution that is applied only
+ within chunks, that we multiply by a scaling factor which depends
+ on the position within the chunk.
+
+    Args:
+        channels: the number of input and output channels (the convolution is
+           depthwise, so these are equal).
+        kernel_size: the kernel size of the chunkwise convolution; must be odd.
+           The causal convolution uses a kernel of size (kernel_size + 1) // 2.
+        initial_scale: you can override this if you want to increase or decrease
+           the initial magnitude of the module's output.
+        bias: if True, the chunkwise convolution has a bias term (the causal
+           convolution always has one).
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ initial_scale: float = 1.0,
+ bias: bool = True,
+ ):
+ super().__init__()
+ assert kernel_size % 2 == 1
+
+ half_kernel_size = (kernel_size + 1) // 2
+ # will pad manually, on one side.
+ self.causal_conv = nn.Conv1d(
+ in_channels=channels,
+ out_channels=channels,
+ groups=channels,
+ kernel_size=half_kernel_size,
+ padding=0,
+ bias=True,
+ )
+
+ self.chunkwise_conv = nn.Conv1d(
+ in_channels=channels,
+ out_channels=channels,
+ groups=channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ bias=bias,
+ )
+
+ # first row is correction factors added to the scale near the left edge of the chunk,
+ # second row is correction factors added to the scale near the right edge of the chunk,
+ # both of these are added to a default scale of 1.0.
+ self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
+ self.kernel_size = kernel_size
+
+ with torch.no_grad():
+ self.causal_conv.weight[:] *= initial_scale
+ self.chunkwise_conv.weight[:] *= initial_scale
+ if bias:
+ torch.nn.init.uniform_(
+ self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
+ )
+
+ def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
+ """
+ Forward function. Args:
+ x: a Tensor of shape (batch_size, channels, seq_len)
+ chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
+ """
+ (batch_size, num_channels, seq_len) = x.shape
+
+        # half_kernel_size = (self.kernel_size + 1) // 2
+ # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
+ # in the causal conv. It's the amount by which we must pad on the left,
+ # to make the convolution causal.
+ left_pad = self.kernel_size // 2
+
+ if chunk_size < 0 or chunk_size > seq_len:
+ chunk_size = seq_len
+ right_pad = -seq_len % chunk_size
+
+ x = torch.nn.functional.pad(x, (left_pad, right_pad))
+
+ x_causal = self.causal_conv(x[..., : left_pad + seq_len])
+ assert x_causal.shape == (batch_size, num_channels, seq_len)
+
+ x_chunk = x[..., left_pad:]
+ num_chunks = x_chunk.shape[2] // chunk_size
+ x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
+ x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
+ batch_size * num_chunks, num_channels, chunk_size
+ )
+ x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
+
+ chunk_scale = self._get_chunk_scale(chunk_size)
+
+ x_chunk = x_chunk * chunk_scale
+ x_chunk = x_chunk.reshape(
+ batch_size, num_chunks, num_channels, chunk_size
+ ).permute(0, 2, 1, 3)
+ x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
+ ..., :seq_len
+ ]
+
+ return x_chunk + x_causal
+
+ def _get_chunk_scale(self, chunk_size: int):
+ """Returns tensor of shape (num_channels, chunk_size) that will be used to
+ scale the output of self.chunkwise_conv."""
+ left_edge = self.chunkwise_conv_scale[0]
+ right_edge = self.chunkwise_conv_scale[1]
+ if chunk_size < self.kernel_size:
+ left_edge = left_edge[:, :chunk_size]
+ right_edge = right_edge[:, -chunk_size:]
+ else:
+ t = chunk_size - self.kernel_size
+ channels = left_edge.shape[0]
+ pad = torch.zeros(
+ channels, t, device=left_edge.device, dtype=left_edge.dtype
+ )
+ left_edge = torch.cat((left_edge, pad), dim=-1)
+ right_edge = torch.cat((pad, right_edge), dim=-1)
+ return 1.0 + (left_edge + right_edge)
+
+ def streaming_forward(
+ self,
+ x: Tensor,
+ cache: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Streaming Forward function.
+
+ Args:
+ x: a Tensor of shape (batch_size, channels, seq_len)
+ cache: cached left context of shape (batch_size, channels, left_pad)
+ """
+ (batch_size, num_channels, seq_len) = x.shape
+
+ # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
+ # in the causal conv. It's the amount by which we must pad on the left,
+ # to make the convolution causal.
+ left_pad = self.kernel_size // 2
+
+ # Pad cache
+ assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
+ x = torch.cat([cache, x], dim=2)
+ # Update cache
+ cache = x[..., -left_pad:]
+
+ x_causal = self.causal_conv(x)
+ assert x_causal.shape == (batch_size, num_channels, seq_len)
+
+ x_chunk = x[..., left_pad:]
+ x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
+
+ chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
+ x_chunk = x_chunk * chunk_scale
+
+ return x_chunk + x_causal, cache
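+
+# A usage sketch for ChunkCausalDepthwiseConv1d (illustrative shapes only):
+#
+#   conv = ChunkCausalDepthwiseConv1d(channels=256, kernel_size=31)
+#   x = torch.randn(4, 256, 200)   # (batch, channels, seq_len)
+#   y = conv(x, chunk_size=50)     # same shape; causal across 50-frame chunks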
+
+
+class BalancerFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ min_mean: float,
+ max_mean: float,
+ min_rms: float,
+ max_rms: float,
+ grad_scale: float,
+ channel_dim: int,
+ ) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ ctx.channel_dim = channel_dim
+ ctx.save_for_backward(x)
+ ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
+ (x,) = ctx.saved_tensors
+ (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
+
+ try:
+ with torch.enable_grad():
+ with torch.amp.autocast("cuda", enabled=False):
+ x = x.to(torch.float32)
+ x = x.detach()
+ x.requires_grad = True
+ mean_dims = [i for i in range(x.ndim) if i != channel_dim]
+ uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
+ mean = x.mean(dim=mean_dims, keepdim=True)
+ stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
+ rms = uncentered_var.clamp(min=1.0e-20).sqrt()
+
+ m = mean / stddev
+ # part of loss that relates to mean / stddev
+ m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
+
+ # put a much larger scale on the RMS-max-limit loss, so that if both it and the
+ # m_loss are violated we fix the RMS loss first.
+ rms_clamped = rms.clamp(min=min_rms, max=max_rms)
+ r_loss = (rms_clamped / rms).log().abs()
+
+ loss = m_loss + r_loss
+
+ loss.backward(gradient=torch.ones_like(loss))
+ loss_grad = x.grad
+ loss_grad_rms = (
+ (loss_grad**2)
+ .mean(dim=mean_dims, keepdim=True)
+ .sqrt()
+ .clamp(min=1.0e-20)
+ )
+
+ loss_grad = loss_grad * (grad_scale / loss_grad_rms)
+
+ x_grad_float = x_grad.to(torch.float32)
+ # scale each element of loss_grad by the absolute value of the corresponding
+ # element of x_grad, which we view as a noisy estimate of its magnitude for that
+ # (frame and dimension). later we can consider factored versions.
+ x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+ x_grad = x_grad_mod.to(x_grad.dtype)
+ except Exception as e:
+ logging.info(
+ f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
+ )
+
+ return x_grad, None, None, None, None, None, None
+
+
+class Balancer(torch.nn.Module):
+ """
+    Modifies the backpropped derivatives of a function to try to encourage, for
+    each channel, that the proportion of positive values and the average magnitude
+    stay within the configured ranges. It does this by adding, with some probability
+    on each forward(), a small extra gradient term in backprop whenever the
+    per-channel statistics of the input violate those constraints.
+
+ Args:
+ num_channels: the number of channels
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+ min_positive: the minimum, per channel, of the proportion of the time
+ that (x > 0), below which we start to modify the derivatives.
+ max_positive: the maximum, per channel, of the proportion of the time
+ that (x > 0), above which we start to modify the derivatives.
+      grad_scale: determines the scale of the extra gradient term that we add
+          when the constraints on the proportion-positive and abs-value ranges
+          are violated, relative to the magnitude of the existing gradient.
+ min_abs: the minimum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ max_abs: the maximum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ prob: determines the minimum probability with which we modify the
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
+ on each forward(). This is done randomly to prevent all layers
+ from doing it at the same time.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int,
+ min_positive: FloatLike = 0.05,
+ max_positive: FloatLike = 0.95,
+ min_abs: FloatLike = 0.2,
+ max_abs: FloatLike = 100.0,
+ grad_scale: FloatLike = 0.04,
+ prob: Optional[FloatLike] = None,
+ ):
+ super().__init__()
+
+ if prob is None:
+ prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
+ self.prob = prob
+ # 5% of the time we will return and do nothing because memory usage is
+ # too high.
+ self.mem_cutoff = CutoffEstimator(0.05)
+
+ # actually self.num_channels is no longer needed except for an assertion.
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.min_positive = min_positive
+ self.max_positive = max_positive
+ self.min_abs = min_abs
+ self.max_abs = max_abs
+ self.grad_scale = grad_scale
+
+ def forward(self, x: Tensor) -> Tensor:
+ if (
+ torch.jit.is_scripting()
+ or not x.requires_grad
+ or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
+ ):
+ return _no_op(x)
+
+ prob = float(self.prob)
+ if random.random() < prob:
+ # The following inner-functions convert from the way we historically specified
+ # these limitations, as limits on the absolute value and the proportion of positive
+ # values, to limits on the RMS value and the (mean / stddev).
+ def _abs_to_rms(x):
+ # for normally distributed data, if the expected absolute value is x, the
+ # expected rms value will be sqrt(pi/2) * x.
+ return 1.25331413732 * x
+
+ def _proportion_positive_to_mean(x):
+ def _atanh(x):
+ eps = 1.0e-10
+ # eps is to prevent crashes if x is exactly 0 or 1.
+ # we'll just end up returning a fairly large value.
+ return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
+
+ def _approx_inverse_erf(x):
+ # 1 / (sqrt(pi) * ln(2)),
+ # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
+ # this approximation is extremely crude and gets progressively worse for
+ # x very close to -1 or +1, but we mostly care about the "middle" region
+ # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
+ # and math.erf(0.0407316414078772) = 0.045935330944660666,
+ # which is pretty close to 0.05.
+ return 0.8139535143 * _atanh(x)
+
+ # first convert x from the range 0..1 to the range -1..1 which the error
+ # function returns
+ x = -1 + (2 * x)
+ return _approx_inverse_erf(x)
+
+ min_mean = _proportion_positive_to_mean(float(self.min_positive))
+ max_mean = _proportion_positive_to_mean(float(self.max_positive))
+ min_rms = _abs_to_rms(float(self.min_abs))
+ max_rms = _abs_to_rms(float(self.max_abs))
+ grad_scale = float(self.grad_scale)
+
+ assert x.shape[self.channel_dim] == self.num_channels
+
+ return BalancerFunction.apply(
+ x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
+ )
+ else:
+ return _no_op(x)
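+
+# A usage sketch for Balancer (illustrative values): constrain a 384-channel
+# activation, over its last dim, to stay mostly positive and within a moderate
+# magnitude range.  Balancer is an identity in the forward pass and only
+# modifies gradients in backprop:
+#
+#   balancer = Balancer(384, channel_dim=-1, min_positive=0.05, max_positive=0.95,
+#                       min_abs=0.2, max_abs=4.0)
+#   y = balancer(x)   # y equals x numerically; use y downstream so backprop sees it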
+
+
+def penalize_abs_values_gt(
+ x: Tensor, limit: float, penalty: float, name: str = None
+) -> Tensor:
+ """
+ Returns x unmodified, but in backprop will put a penalty for the excess of
+ the absolute values of elements of x over the limit "limit". E.g. if
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
+
+ Caution: the value of this penalty will be affected by grad scaling used
+    in automatic mixed precision training. For the purposes we use this for,
+    that shouldn't really matter, and may even be helpful; we just use this
+    to disallow really implausible score values from being given to softmax.
+
+ The name is for randomly printed debug info.
+ """
+ x_sign = x.sign()
+ over_limit = (x.abs() - limit) > 0
+ # The following is a memory efficient way to penalize the absolute values of
+ # x that's over the limit. (The memory efficiency comes when you think
+ # about which items torch needs to cache for the autograd, and which ones it
+ # can throw away). The numerical value of aux_loss as computed here will
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
+ # limit).relu().
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
+    # note: we don't do sum() here on aux_loss, but it's as if we had done
+ # sum() due to how with_loss() works.
+ x = with_loss(x, aux_loss, name)
+ # you must use x for something, or this will be ineffective.
+ return x
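+
+# Illustrative use of penalize_abs_values_gt (hypothetical values): keep attention
+# scores from growing implausibly large before a softmax:
+#   scores = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04, name="attn")
+# The returned tensor must be the one that is actually used downstream.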
+
+
+def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
+ if x.ndim == 2:
+ return x.diag()
+ else:
+ (batch, dim, dim) = x.shape
+ x = x.reshape(batch, dim * dim)
+ x = x[:, :: dim + 1]
+ assert x.shape == (batch, dim)
+ return x
+
+
+def _whitening_metric(x: Tensor, num_groups: int):
+ """
+    Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues
+    of the centered feature covariance are the same within each group's covariance matrix
+ and also between groups.
+ Args:
+ x: a Tensor of shape (*, num_channels)
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
+    Returns:
+      A scalar Tensor that will be 1.0 if the data is "perfectly white" and
+      greater than 1.0 otherwise.
+ """
+ assert x.dtype != torch.float16
+ x = x.reshape(-1, x.shape[-1])
+ (num_frames, num_channels) = x.shape
+ assert num_channels % num_groups == 0
+ channels_per_group = num_channels // num_groups
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
+ # x now has shape (num_groups, num_frames, channels_per_group)
+ # subtract the mean so we use the centered, not uncentered, covariance.
+ # My experience has been that when we "mess with the gradients" like this,
+    # it's better not to do anything that tries to move the mean around, because
+ # that can easily cause instability.
+ x = x - x.mean(dim=1, keepdim=True)
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
+ x_covar = torch.matmul(x.transpose(1, 2), x)
+ x_covar_mean_diag = _diag(x_covar).mean()
+ # the following expression is what we'd get if we took the matrix product
+ # of each covariance and measured the mean of its trace, i.e.
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
+ x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
+ return metric
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
+ ctx.save_for_backward(x)
+ ctx.module = module
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x_orig,) = ctx.saved_tensors
+ w = ctx.module
+
+ try:
+ with torch.enable_grad():
+ with torch.amp.autocast("cuda", enabled=False):
+ x_detached = x_orig.to(torch.float32).detach()
+ x_detached.requires_grad = True
+
+ metric = _whitening_metric(x_detached, w.num_groups)
+
+ if random.random() < 0.005 or __name__ == "__main__":
+ logging.debug(
+ f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
+ f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}"
+ )
+
+ if metric < float(w.whitening_limit):
+ w.prob = w.min_prob
+ return x_grad, None
+ else:
+ w.prob = w.max_prob
+ metric.backward()
+ penalty_grad = x_detached.grad
+ scale = w.grad_scale * (
+ x_grad.to(torch.float32).norm()
+ / (penalty_grad.norm() + 1.0e-20)
+ )
+ penalty_grad = penalty_grad * scale
+ return x_grad + penalty_grad.to(x_grad.dtype), None
+ except Exception as e:
+ logging.info(
+ f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue."
+ )
+ return x_grad, None
+
+
+class Whiten(nn.Module):
+ def __init__(
+ self,
+ num_groups: int,
+ whitening_limit: FloatLike,
+ prob: Union[float, Tuple[float, float]],
+ grad_scale: FloatLike,
+ ):
+ """
+ Args:
+ num_groups: the number of groups to divide the channel dim into before
+ whitening. We will attempt to make the feature covariance
+ within each group, after mean subtraction, as "white" as possible,
+ while having the same trace across all groups.
+ whitening_limit: a value greater than 1.0, that dictates how much
+ freedom we have to violate the constraints. 1.0 would mean perfectly
+ white, with exactly the same trace across groups; larger values
+ give more freedom. E.g. 2.0.
+ prob: the probability with which we apply the gradient modification
+ (also affects the grad scale). May be supplied as a float,
+ or as a pair (min_prob, max_prob)
+
+ grad_scale: determines the scale on the gradient term from this object,
+ relative to the rest of the gradient on the attention weights.
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
+ """
+ super(Whiten, self).__init__()
+ assert num_groups >= 1
+ assert float(whitening_limit) >= 1
+ assert grad_scale >= 0
+ self.num_groups = num_groups
+ self.whitening_limit = whitening_limit
+ self.grad_scale = grad_scale
+
+ if isinstance(prob, float):
+ prob = (prob, prob)
+ (self.min_prob, self.max_prob) = prob
+ assert 0 < self.min_prob <= self.max_prob <= 1
+ self.prob = self.max_prob
+ self.name = None # will be set in training loop
+
+ def forward(self, x: Tensor) -> Tensor:
+ """
+ In the forward pass, this function just returns the input unmodified.
+ In the backward pass, it will modify the gradients to ensure that the
+ distribution in each group has close to (lambda times I) as the covariance
+ after mean subtraction, with the same lambda across groups.
+ For whitening_limit > 1, there will be more freedom to violate this
+ constraint.
+
+ Args:
+ x: the input of shape (*, num_channels)
+
+ Returns:
+ x, unmodified. You should make sure
+ you use the returned value, or the graph will be freed
+ and nothing will happen in backprop.
+ """
+ grad_scale = float(self.grad_scale)
+ if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
+ return _no_op(x)
+ else:
+ return WhiteningPenaltyFunction.apply(x, self)
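+
+# A usage sketch for Whiten (illustrative values): identity in the forward pass;
+# in backprop it nudges the features toward a whiter covariance whenever the
+# whitening metric exceeds the limit:
+#
+#   whiten = Whiten(num_groups=1, whitening_limit=2.0, prob=(0.025, 0.25), grad_scale=0.01)
+#   y = whiten(x)   # use y downstream, or the penalty has no effect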
+
+
+class WithLoss(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, y: Tensor, name: str):
+ ctx.y_shape = y.shape
+ if random.random() < 0.002 and name is not None:
+ loss_sum = y.sum().item()
+ logging.debug(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
+ return x
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ return (
+ ans_grad,
+ torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
+ None,
+ )
+
+
+def with_loss(x, y, name):
+ # returns x but adds y.sum() to the loss function.
+ return WithLoss.apply(x, y, name)
+
+
+class ScaleGradFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, alpha: float) -> Tensor:
+ ctx.alpha = alpha
+ return x
+
+ @staticmethod
+ def backward(ctx, grad: Tensor):
+ return grad * ctx.alpha, None
+
+
+def scale_grad(x: Tensor, alpha: float):
+ return ScaleGradFunction.apply(x, alpha)
+
+
+class ScaleGrad(nn.Module):
+ def __init__(self, alpha: float):
+ super().__init__()
+ self.alpha = alpha
+
+ def forward(self, x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return x
+ return scale_grad(x, self.alpha)
+
+
+class LimitParamValue(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, min: float, max: float):
+ ctx.save_for_backward(x)
+ assert max >= min
+ ctx.min = min
+ ctx.max = max
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x,) = ctx.saved_tensors
+ # where x < ctx.min, ensure all grads are negative (this will tend to make
+ # x more positive).
+ x_grad = x_grad * torch.where(
+ torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
+ )
+ # where x > ctx.max, ensure all grads are positive (this will tend to make
+ # x more negative).
+ x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
+ return x_grad, None, None
+
+
+def limit_param_value(
+ x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
+):
+ # You apply this to (typically) an nn.Parameter during training to ensure that its
+ # (elements mostly) stays within a supplied range. This is done by modifying the
+ # gradients in backprop.
+ # It's not necessary to do this on every batch: do it only some of the time,
+ # to save a little time.
+ if training and random.random() < prob:
+ return LimitParamValue.apply(x, min, max)
+ else:
+ return x
+
+
+def _no_op(x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x
+ else:
+ # a no-op function that will have a node in the autograd graph,
+ # to avoid certain bugs relating to backward hooks
+ return x.chunk(1, dim=-1)[0]
+
+
+class Identity(torch.nn.Module):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, x):
+ return _no_op(x)
+
+
+class DoubleSwishFunction(torch.autograd.Function):
+ """
+ double_swish(x) = x * torch.sigmoid(x-1)
+
+ This is a definition, originally motivated by its close numerical
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
+
+ Memory-efficient derivative computation:
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
+ Now, s'(x) = s(x) * (1-s(x)).
+ double_swish'(x) = x * s'(x) + s(x).
+ = x * s(x) * (1-s(x)) + s(x).
+ = double_swish(x) * (1-s(x)) + s(x)
+ ... so we just need to remember s(x) but not x itself.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ s = torch.sigmoid(x - 1.0)
+ y = x * s
+
+ if requires_grad:
+ deriv = y * (1 - s) + s
+
+ # notes on derivative of x * sigmoid(x - 1):
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
+            # min \simeq -0.043638. Take floor as -0.044 so it's a lower bound
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
+ # floors), should be expectation-preserving.
+ floor = -0.044
+ ceil = 1.2
+ d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ deriv
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ floor = -0.043637
+ ceil = 1.2
+
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class DoubleSwish(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+ that we approximate closely with x * sigmoid(x-1).
+ """
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x * torch.sigmoid(x - 1.0)
+ return DoubleSwishFunction.apply(x)
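+
+# Some reference values of double_swish(x) = x * sigmoid(x - 1):
+#   double_swish(0) = 0, double_swish(1) = 0.5, and double_swish(x) -> x for large x.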
+
+
+# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
+class Dropout2(nn.Module):
+ def __init__(self, p: FloatLike):
+ super().__init__()
+ self.p = p
+
+ def forward(self, x: Tensor) -> Tensor:
+ return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
+
+
+class MulForDropout3(torch.autograd.Function):
+ # returns (x * y * alpha) where alpha is a float and y doesn't require
+ # grad and is zero-or-one.
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, x, y, alpha):
+ assert not y.requires_grad
+ ans = x * y * alpha
+ ctx.save_for_backward(ans)
+ ctx.alpha = alpha
+ return ans
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, ans_grad):
+ (ans,) = ctx.saved_tensors
+ x_grad = ctx.alpha * ans_grad * (ans != 0)
+ return x_grad, None, None
+
+
+# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
+# and it lets you choose one dimension to share the dropout mask over
+class Dropout3(nn.Module):
+ def __init__(self, p: FloatLike, shared_dim: int):
+ super().__init__()
+ self.p = p
+ self.shared_dim = shared_dim
+
+ def forward(self, x: Tensor) -> Tensor:
+ p = float(self.p)
+ if not self.training or p == 0:
+ return _no_op(x)
+ scale = 1.0 / (1 - p)
+ rand_shape = list(x.shape)
+ rand_shape[self.shared_dim] = 1
+ mask = torch.rand(*rand_shape, device=x.device) > p
+ ans = MulForDropout3.apply(x, mask, scale)
+ return ans
+
+
+class SwooshLFunction(torch.autograd.Function):
+ """
+ swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+ coeff = -0.08
+
+ with torch.amp.autocast("cuda", enabled=False):
+ with torch.enable_grad():
+ x = x.detach()
+ x.requires_grad = True
+ y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
+
+ if not requires_grad:
+ return y
+
+ y.backward(gradient=torch.ones_like(y))
+
+ grad = x.grad
+ floor = coeff
+ ceil = 1.0 + coeff + 0.005
+
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ grad
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+
+ coeff = -0.08
+ floor = coeff
+ ceil = 1.0 + coeff + 0.005
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class SwooshL(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-L activation."""
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+ if not x.requires_grad:
+ return k2.swoosh_l_forward(x)
+ else:
+ return k2.swoosh_l(x)
+ # return SwooshLFunction.apply(x)
+
+
+class SwooshLOnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-L activation."""
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
+
+
+class SwooshRFunction(torch.autograd.Function):
+ """
+ swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
+
+ derivatives are between -0.08 and 0.92.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+ with torch.amp.autocast("cuda", enabled=False):
+ with torch.enable_grad():
+ x = x.detach()
+ x.requires_grad = True
+ y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
+
+ if not requires_grad:
+ return y
+ y.backward(gradient=torch.ones_like(y))
+
+ grad = x.grad
+ floor = -0.08
+ ceil = 0.925
+
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
+ grad
+ )
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ floor = -0.08
+ ceil = 0.925
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class SwooshR(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-R activation."""
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
+ if not x.requires_grad:
+ return k2.swoosh_r_forward(x)
+ else:
+ return k2.swoosh_r(x)
+ # return SwooshRFunction.apply(x)
+
+
+class SwooshROnnx(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swoosh-R activation."""
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+ return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
+
+
+# simple version of SwooshL that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshLForward(x: Tensor):
+ x_offset = x - 4.0
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
+ return log_sum - 0.08 * x - 0.035
+
+
+# simple version of SwooshR that does not redefine the backprop, used in
+# ActivationDropoutAndLinearFunction.
+def SwooshRForward(x: Tensor):
+ x_offset = x - 1.0
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
+ return log_sum - 0.08 * x - 0.313261687
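+
+# Reference behaviour of the Swoosh activations (from the formulas above):
+#   swoosh_r(0) = log(1 + exp(-1)) - 0.313261687 ~= 0 (the constant is log(1 + e^-1)),
+#   and both activations approach slope -0.08 for very negative x and 0.92 for large x.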
+
+
+class ActivationDropoutAndLinearFunction(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(
+ ctx,
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor],
+ activation: str,
+ dropout_p: float,
+ dropout_shared_dim: Optional[int],
+ ):
+ if dropout_p != 0.0:
+ dropout_shape = list(x.shape)
+ if dropout_shared_dim is not None:
+ dropout_shape[dropout_shared_dim] = 1
+ # else it won't be very memory efficient.
+ dropout_mask = (1.0 / (1.0 - dropout_p)) * (
+ torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
+ )
+ else:
+ dropout_mask = None
+
+ ctx.save_for_backward(x, weight, bias, dropout_mask)
+
+ ctx.activation = activation
+
+ forward_activation_dict = {
+ "SwooshL": k2.swoosh_l_forward,
+ "SwooshR": k2.swoosh_r_forward,
+ }
+        # this will raise a KeyError if the activation is unrecognized;
+        # we let that error propagate to the user.
+ activation_func = forward_activation_dict[activation]
+ x = activation_func(x)
+ if dropout_mask is not None:
+ x = x * dropout_mask
+ x = torch.nn.functional.linear(x, weight, bias)
+ return x
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, ans_grad: Tensor):
+ saved = ctx.saved_tensors
+ (x, weight, bias, dropout_mask) = saved
+
+ forward_and_deriv_activation_dict = {
+ "SwooshL": k2.swoosh_l_forward_and_deriv,
+ "SwooshR": k2.swoosh_r_forward_and_deriv,
+ }
+        # the following line raises a KeyError if the activation is unrecognized;
+        # we let that error propagate to the user.
+ func = forward_and_deriv_activation_dict[ctx.activation]
+
+ y, func_deriv = func(x)
+ if dropout_mask is not None:
+ y = y * dropout_mask
+        # now compute the gradients w.r.t. the weight and bias.
+ # y: (..., in_channels), ans_grad: (..., out_channels),
+ (out_channels, in_channels) = weight.shape
+
+ in_channels = y.shape[-1]
+ g = ans_grad.reshape(-1, out_channels)
+ weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
+ y_deriv = torch.matmul(ans_grad, weight)
+ bias_deriv = None if bias is None else g.sum(dim=0)
+ x_deriv = y_deriv * func_deriv
+ if dropout_mask is not None:
+ # order versus func_deriv does not matter
+ x_deriv = x_deriv * dropout_mask
+
+ return x_deriv, weight_deriv, bias_deriv, None, None, None
+
+
+class ActivationDropoutAndLinear(torch.nn.Module):
+ """
+ This merges an activation function followed by dropout and then a nn.Linear module;
+ it does so in a memory efficient way so that it only stores the input to the whole
+ module. If activation == SwooshL and dropout_shared_dim != None, this will be
+ equivalent to:
+ nn.Sequential(SwooshL(),
+ Dropout3(dropout_p, shared_dim=dropout_shared_dim),
+ ScaledLinear(in_channels, out_channels, bias=bias,
+ initial_scale=initial_scale))
+ If dropout_shared_dim is None, the dropout would be equivalent to
+ Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout
+ mask is smaller.
+
+ Args:
+ in_channels: number of input channels, e.g. 256
+ out_channels: number of output channels, e.g. 256
+ bias: if true, have a bias
+ activation: the activation function, for now just support SwooshL.
+ dropout_p: the dropout probability or schedule (happens after nonlinearity).
+ dropout_shared_dim: the dimension, if any, across which the dropout mask is
+ shared (e.g. the time dimension). If None, this may be less memory
+ efficient if there are modules before this one that cache the input
+ for their backprop (e.g. Balancer or Whiten).
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bias: bool = True,
+ activation: str = "SwooshL",
+ dropout_p: FloatLike = 0.0,
+ dropout_shared_dim: Optional[int] = -1,
+ initial_scale: float = 1.0,
+ ):
+ super().__init__()
+ # create a temporary module of nn.Linear that we'll steal the
+ # weights and bias from
+ l = ScaledLinear(
+ in_channels, out_channels, bias=bias, initial_scale=initial_scale
+ )
+
+ self.weight = l.weight
+ # register_parameter properly handles making it a parameter when l.bias
+ # is None. I think there is some reason for doing it this way rather
+ # than just setting it to None but I don't know what it is, maybe
+ # something to do with exporting the module..
+ self.register_parameter("bias", l.bias)
+
+ self.activation = activation
+ self.dropout_p = dropout_p
+ self.dropout_shared_dim = dropout_shared_dim
+
+ def forward(self, x: Tensor):
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ if self.activation == "SwooshL":
+ x = SwooshLForward(x)
+ elif self.activation == "SwooshR":
+ x = SwooshRForward(x)
+ else:
+ assert False, self.activation
+ return torch.nn.functional.linear(x, self.weight, self.bias)
+
+ return ActivationDropoutAndLinearFunction.apply(
+ x,
+ self.weight,
+ self.bias,
+ self.activation,
+ float(self.dropout_p),
+ self.dropout_shared_dim,
+ )
+
+
+def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
+ if num_channels <= x.shape[-1]:
+ return x[..., :num_channels]
+ else:
+ shape = list(x.shape)
+ shape[-1] = num_channels - shape[-1]
+ zeros = torch.zeros(shape, dtype=x.dtype, device=x.device)
+ return torch.cat((x, zeros), dim=-1)
+
+
+def _test_whiten():
+ for proportion in [0.1, 0.5, 10.0]:
+ logging.info(f"_test_whiten(): proportion = {proportion}")
+ x = torch.randn(100, 128)
+ direction = torch.randn(128)
+ coeffs = torch.randn(100, 1)
+ x += proportion * direction * coeffs
+
+ x.requires_grad = True
+
+ m = Whiten(
+            num_groups=1, whitening_limit=5.0, prob=1.0, grad_scale=0.1
+        )
+
+ for _ in range(4):
+ y = m(x)
+
+ y_grad = torch.randn_like(x)
+ y.backward(gradient=y_grad)
+
+ if proportion < 0.2:
+ assert torch.allclose(x.grad, y_grad)
+ elif proportion > 1.0:
+ assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_balancer_sign():
+ probs = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
+ x = x.detach()
+ x.requires_grad = True
+ m = Balancer(
+ probs.numel(),
+ channel_dim=0,
+ min_positive=0.05,
+ max_positive=0.95,
+ min_abs=0.0,
+ prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_balancer_sign: x = ", x)
+ print("_test_balancer_sign: y grad = ", y_grad)
+ print("_test_balancer_sign: x grad = ", x.grad)
+
+
+def _test_balancer_magnitude():
+ magnitudes = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
+ x = x.detach()
+ x.requires_grad = True
+ m = Balancer(
+ magnitudes.numel(),
+ channel_dim=0,
+ min_positive=0.0,
+ max_positive=1.0,
+ min_abs=0.2,
+ max_abs=0.7,
+ prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_balancer_magnitude: x = ", x)
+ print("_test_balancer_magnitude: y grad = ", y_grad)
+ print("_test_balancer_magnitude: x grad = ", x.grad)
+
+
+def _test_double_swish_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = DoubleSwish()
+
+ tol = (1.2 - (-0.043637)) / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_swooshl_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = SwooshL()
+
+ tol = 1.0 / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_swooshr_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = SwooshR()
+
+ tol = 1.0 / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_softmax():
+ a = torch.randn(2, 10, dtype=torch.float64)
+ b = a.clone()
+ a.requires_grad = True
+ b.requires_grad = True
+ a.softmax(dim=1)[:, 0].sum().backward()
+ print("a grad = ", a.grad)
+ softmax(b, dim=1)[:, 0].sum().backward()
+ print("b grad = ", b.grad)
+ assert torch.allclose(a.grad, b.grad)
+
+
+def _test_piecewise_linear():
+ p = PiecewiseLinear((0, 10.0))
+ for x in [-100, 0, 100]:
+ assert p(x) == 10.0
+ p = PiecewiseLinear((0, 10.0), (1, 0.0))
+ for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
+ print("x, y = ", x, y)
+ assert p(x) == y, (x, p(x), y)
+
+ q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
+ x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
+ pq = p.max(q)
+ for x in x_vals:
+ y1 = max(p(x), q(x))
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+ pq = p.min(q)
+ for x in x_vals:
+ y1 = min(p(x), q(x))
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+ pq = p + q
+ for x in x_vals:
+ y1 = p(x) + q(x)
+ y2 = pq(x)
+ assert abs(y1 - y2) < 0.001
+
+
+def _test_activation_dropout_and_linear():
+ in_channels = 20
+ out_channels = 30
+
+ for bias in [True, False]:
+ # actually we don't test for dropout_p != 0.0 because forward functions will give
+ # different answers. This is because we are using the k2 implementation of
+        # swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
+ # internally, messing up the random state.
+ for dropout_p in [0.0]:
+ for activation in ["SwooshL", "SwooshR"]:
+ m1 = nn.Sequential(
+ SwooshL() if activation == "SwooshL" else SwooshR(),
+ Dropout3(p=dropout_p, shared_dim=-1),
+ ScaledLinear(
+ in_channels, out_channels, bias=bias, initial_scale=0.5
+ ),
+ )
+ m2 = ActivationDropoutAndLinear(
+ in_channels,
+ out_channels,
+ bias=bias,
+ initial_scale=0.5,
+ activation=activation,
+ dropout_p=dropout_p,
+ )
+ with torch.no_grad():
+ m2.weight[:] = m1[2].weight
+ if bias:
+ m2.bias[:] = m1[2].bias
+ # make sure forward gives same result.
+ x1 = torch.randn(10, in_channels)
+ x1.requires_grad = True
+
+ # TEMP.
+ assert torch.allclose(
+ SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
+ )
+
+ x2 = x1.clone().detach()
+ x2.requires_grad = True
+ seed = 10
+ torch.manual_seed(seed)
+ y1 = m1(x1)
+ y_grad = torch.randn_like(y1)
+ y1.backward(gradient=y_grad)
+ torch.manual_seed(seed)
+ y2 = m2(x2)
+ y2.backward(gradient=y_grad)
+
+ print(
+ f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
+ )
+ print("y1 = ", y1)
+ print("y2 = ", y2)
+ assert torch.allclose(y1, y2, atol=0.02)
+ assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
+ if bias:
+ assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
+ print("x1.grad = ", x1.grad)
+ print("x2.grad = ", x2.grad)
+
+ def isclose(a, b):
+ # return true if cosine similarity is > 0.9.
+ return (a * b).sum() > 0.9 * (
+ (a**2).sum() * (b**2).sum()
+ ).sqrt()
+
+ # the SwooshL() implementation has a noisy gradient due to 1-byte
+ # storage of it.
+ assert isclose(x1.grad, x2.grad)
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.DEBUG)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_piecewise_linear()
+ _test_softmax()
+ _test_whiten()
+ _test_balancer_sign()
+ _test_balancer_magnitude()
+ _test_double_swish_deriv()
+ _test_swooshr_deriv()
+ _test_swooshl_deriv()
+ _test_activation_dropout_and_linear()
diff --git a/egs/zipvoice/zipvoice/solver.py b/egs/zipvoice/zipvoice/solver.py
new file mode 100644
index 000000000..a1e316ec8
--- /dev/null
+++ b/egs/zipvoice/zipvoice/solver.py
@@ -0,0 +1,277 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+
+
+class DiffusionModel(torch.nn.Module):
+ """A wrapper of diffusion models for inference.
+ Args:
+ model: The diffusion model.
+ distill: Whether it is a distillation model.
+ """
+
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ distill: bool = False,
+ func_name: str = "forward_fm_decoder",
+ ):
+ super().__init__()
+ self.model = model
+ self.distill = distill
+ self.func_name = func_name
+ self.model_func = getattr(self.model, func_name)
+
+ def forward(
+ self,
+ t: torch.Tensor,
+ x: torch.Tensor,
+ text_condition: torch.Tensor,
+ speech_condition: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
+ **kwargs
+ ) -> torch.Tensor:
+ """
+        Forward function that handles classifier-free guidance.
+        Args:
+            t: The current timestep, a tensor of shape (batch, 1, 1) or a scalar tensor.
+            x: The initial value, with the shape (batch, seq_len, emb_dim).
+            text_condition: The text condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
+            speech_condition: The speech condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
+            padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
+            guidance_scale: The scale of classifier-free guidance, a float or a tensor of shape (batch, 1, 1).
+        Returns:
+ The prediction with the shape (batch, seq_len, emb_dim).
+ """
+ if not torch.is_tensor(guidance_scale):
+ guidance_scale = torch.tensor(
+ guidance_scale, dtype=t.dtype, device=t.device
+ )
+ if self.distill:
+ return self.model_func(
+ t=t,
+ xt=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ guidance_scale=guidance_scale,
+ **kwargs
+ )
+
+ if (guidance_scale == 0.0).all():
+ return self.model_func(
+ t=t,
+ xt=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ **kwargs
+ )
+ else:
+ if t.dim() != 0:
+ t = torch.cat([t] * 2, dim=0)
+
+ x = torch.cat([x] * 2, dim=0)
+ padding_mask = torch.cat([padding_mask] * 2, dim=0)
+
+ text_condition = torch.cat(
+ [torch.zeros_like(text_condition), text_condition], dim=0
+ )
+
+ if t.dim() == 0:
+ if t > 0.5:
+ speech_condition = torch.cat(
+ [torch.zeros_like(speech_condition), speech_condition], dim=0
+ )
+ else:
+ guidance_scale = guidance_scale * 2
+ speech_condition = torch.cat(
+ [speech_condition, speech_condition], dim=0
+ )
+ else:
+ assert t.dim() > 0, t
+ larger_t_index = (t > 0.5).squeeze(1).squeeze(1)
+ zero_speech_condition = torch.cat(
+ [torch.zeros_like(speech_condition), speech_condition], dim=0
+ )
+ speech_condition = torch.cat(
+ [speech_condition, speech_condition], dim=0
+ )
+ speech_condition[larger_t_index] = zero_speech_condition[larger_t_index]
+ guidance_scale[~larger_t_index[: larger_t_index.size(0) // 2]] *= 2
+
+ data_uncond, data_cond = self.model_func(
+ t=t,
+ xt=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ **kwargs
+ ).chunk(2, dim=0)
+
+ res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
+ return res
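+
+# Worked example of the guidance combination above (illustrative): with
+# guidance_scale = 1.0 the prediction is 2 * data_cond - 1 * data_uncond, i.e.
+# the conditional prediction pushed further away from the unconditional one.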
+
+
+class EulerSolver:
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ distill: bool = False,
+ func_name: str = "forward_fm_decoder",
+ ):
+        """Construct an Euler solver.
+        Args:
+            model: The diffusion model.
+            distill: Whether it is a distillation model.
+ """
+
+ self.model = DiffusionModel(model, distill=distill, func_name=func_name)
+
+ def sample(
+ self,
+ x: torch.Tensor,
+ text_condition: torch.Tensor,
+ speech_condition: torch.Tensor,
+ padding_mask: torch.Tensor,
+ num_step: int = 10,
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
+ t_start: Union[float, torch.Tensor] = 0.0,
+ t_end: Union[float, torch.Tensor] = 1.0,
+ t_shift: float = 1.0,
+ **kwargs
+ ) -> torch.Tensor:
+ """
+        Compute the sample at time `t_end` with the Euler solver.
+ Args:
+ x: The initial value at time `t_start`, with the shape (batch, seq_len, emb_dim).
+            text_condition: The text condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
+            speech_condition: The speech condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
+ padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
+ num_step: The number of ODE steps.
+ guidance_scale: The scale for classifier-free guidance, which is
+ a float or a tensor with the shape (batch, 1, 1).
+ t_start: the start timestep in the range of [0, 1],
+ which is a float or a tensor with the shape (batch, 1, 1).
+            t_end: the end timestep in the range of [0, 1],
+ which is a float or a tensor with the shape (batch, 1, 1).
+ t_shift: shift the t toward smaller numbers so that the sampling
+ will emphasize low SNR region. Should be in the range of (0, 1].
+ The shifting will be more significant when the number is smaller.
+
+ Returns:
+ The approximated solution at time `t_end`.
+ """
+ device = x.device
+
+ if torch.is_tensor(t_start) and t_start.dim() > 0:
+ timesteps = get_time_steps_batch(
+ t_start=t_start,
+ t_end=t_end,
+ num_step=num_step,
+ t_shift=t_shift,
+ device=device,
+ )
+ else:
+ timesteps = get_time_steps(
+ t_start=t_start,
+ t_end=t_end,
+ num_step=num_step,
+ t_shift=t_shift,
+ device=device,
+ )
+ for step in range(num_step):
+ v = self.model(
+ t=timesteps[step],
+ x=x,
+ text_condition=text_condition,
+ speech_condition=speech_condition,
+ padding_mask=padding_mask,
+ guidance_scale=guidance_scale,
+ **kwargs
+ )
+ x = x + v * (timesteps[step + 1] - timesteps[step])
+ return x
+
+
+def get_time_steps(
+ t_start: float = 0.0,
+ t_end: float = 1.0,
+ num_step: int = 10,
+ t_shift: float = 1.0,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Compute the intermediate time steps for sampling.
+
+ Args:
+ t_start: The starting time of the sampling (default is 0).
+        t_end: The ending time of the sampling (default is 1).
+        num_step: The number of sampling steps.
+ t_shift: shift the t toward smaller numbers so that the sampling
+ will emphasize low SNR region. Should be in the range of (0, 1].
+ The shifting will be more significant when the number is smaller.
+ device: A torch device.
+ Returns:
+        The time steps, with shape (num_step + 1,).
+ """
+
+ timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
+
+ timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
+
+ return timesteps
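+
+# Worked example of the t_shift warp above (illustrative): with t_start=0, t_end=1,
+# num_step=2 and t_shift=0.5, linspace gives [0.0, 0.5, 1.0], and
+# 0.5 * t / (1 - 0.5 * t) maps it to [0.0, 0.333..., 1.0], i.e. the intermediate
+# steps are pulled toward 0 (the low-SNR region).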
+
+
+def get_time_steps_batch(
+ t_start: torch.Tensor,
+ t_end: torch.Tensor,
+ num_step: int = 10,
+ t_shift: float = 1.0,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Compute the intermediate time steps for sampling in the batch mode.
+
+ Args:
+ t_start: The starting time of the sampling (default is 0), with the shape (batch, 1, 1).
+        t_end: The ending time of the sampling (default is 1), with the shape (batch, 1, 1).
+        num_step: The number of sampling steps.
+ t_shift: shift the t toward smaller numbers so that the sampling
+ will emphasize low SNR region. Should be in the range of (0, 1].
+ The shifting will be more significant when the number is smaller.
+ device: A torch device.
+ Returns:
+ The time step with the shape (num_step + 1, N, 1, 1).
+ """
+ while t_start.dim() > 1 and t_start.size(-1) == 1:
+ t_start = t_start.squeeze(-1)
+ while t_end.dim() > 1 and t_end.size(-1) == 1:
+ t_end = t_end.squeeze(-1)
+ assert t_start.dim() == t_end.dim() == 1
+
+ timesteps_shape = (num_step + 1, t_start.size(0))
+ timesteps = torch.zeros(timesteps_shape, device=device)
+
+ for i in range(t_start.size(0)):
+ timesteps[:, i] = torch.linspace(t_start[i], t_end[i], steps=num_step + 1)
+
+ timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
+
+ return timesteps.unsqueeze(-1).unsqueeze(-1)
diff --git a/egs/zipvoice/zipvoice/tokenizer.py b/egs/zipvoice/zipvoice/tokenizer.py
new file mode 100644
index 000000000..87af061e6
--- /dev/null
+++ b/egs/zipvoice/zipvoice/tokenizer.py
@@ -0,0 +1,570 @@
+# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao
+# Han Zhu,
+# Wei Kang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+import unicodedata
+from functools import reduce
+from typing import Dict, List, Optional
+
+import cn2an
+import inflect
+import jieba
+from pypinyin import Style, lazy_pinyin
+from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
+
+try:
+ from piper_phonemize import phonemize_espeak
+except Exception as ex:
+ raise RuntimeError(
+ f"{ex}\nPlease run\n"
+ "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
+ )
+
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
+_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+_number_re = re.compile(r"[0-9]+")
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
+ for x in [
+ ("mrs", "misess"),
+ ("mr", "mister"),
+ ("dr", "doctor"),
+ ("st", "saint"),
+ ("co", "company"),
+ ("jr", "junior"),
+ ("maj", "major"),
+ ("gen", "general"),
+ ("drs", "doctors"),
+ ("rev", "reverend"),
+ ("lt", "lieutenant"),
+ ("hon", "honorable"),
+ ("sgt", "sergeant"),
+ ("capt", "captain"),
+ ("esq", "esquire"),
+ ("ltd", "limited"),
+ ("col", "colonel"),
+ ("ft", "fort"),
+ ("etc", "et cetera"),
+ ("btw", "by the way"),
+ ]
+]
+
+
+def intersperse(sequence, item=0):
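+ # e.g. intersperse([1, 2, 3], 0) -> [0, 1, 0, 2, 0, 3, 0]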
+ result = [item] * (len(sequence) * 2 + 1)
+ result[1::2] = sequence
+ return result
+
+
+def expand_abbreviations(text):
+ for regex, replacement in _abbreviations:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def _remove_commas(m):
+ return m.group(1).replace(",", "")
+
+
+def _expand_decimal_point(m):
+ return m.group(1).replace(".", " point ")
+
+
+def _expand_percent(m):
+ return m.group(1).replace("%", " percent ")
+
+
+def _expand_dollars(m):
+ match = m.group(1)
+ parts = match.split(".")
+ if len(parts) > 2:
+ return " " + match + " dollars " # Unexpected format
+ dollars = int(parts[0]) if parts[0] else 0
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+ if dollars and cents:
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
+ cent_unit = "cent" if cents == 1 else "cents"
+ return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
+ elif dollars:
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
+ return " %s %s " % (dollars, dollar_unit)
+ elif cents:
+ cent_unit = "cent" if cents == 1 else "cents"
+ return " %s %s " % (cents, cent_unit)
+ else:
+ return " zero dollars "
+
+
+def fraction_to_words(numerator, denominator):
+ if numerator == 1 and denominator == 2:
+ return " one half "
+ if numerator == 1 and denominator == 4:
+ return " one quarter "
+ if denominator == 2:
+ return " " + _inflect.number_to_words(numerator) + " halves "
+ if denominator == 4:
+ return " " + _inflect.number_to_words(numerator) + " quarters "
+ return (
+ " "
+ + _inflect.number_to_words(numerator)
+ + " "
+ + _inflect.ordinal(_inflect.number_to_words(denominator))
+ + " "
+ )
+
+
+def _expand_fraction(m):
+ numerator = int(m.group(1))
+ denominator = int(m.group(2))
+ return fraction_to_words(numerator, denominator)
+
+
+def _expand_ordinal(m):
+ return " " + _inflect.number_to_words(m.group(0)) + " "
+
+
+def _expand_number(m):
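+ # Numbers between 1000 and 3000 are read as years (e.g. 1999 -> "nineteen ninety-nine",
+ # 2007 -> "two thousand seven"); other numbers are read as plain cardinals.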
+ num = int(m.group(0))
+ if num > 1000 and num < 3000:
+ if num == 2000:
+ return " two thousand "
+ elif num > 2000 and num < 2010:
+ return " two thousand " + _inflect.number_to_words(num % 100) + " "
+ elif num % 100 == 0:
+ return " " + _inflect.number_to_words(num // 100) + " hundred "
+ else:
+ return (
+ " "
+ + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
+ ", ", " "
+ )
+ + " "
+ )
+ else:
+ return " " + _inflect.number_to_words(num, andword="") + " "
+
+
+# Normalize numbers pronunciation
+def normalize_numbers(text):
+ text = re.sub(_comma_number_re, _remove_commas, text)
+ text = re.sub(_pounds_re, r"\1 pounds", text)
+ text = re.sub(_dollars_re, _expand_dollars, text)
+ text = re.sub(_fraction_re, _expand_fraction, text)
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+ text = re.sub(_percent_number_re, _expand_percent, text)
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
+ text = re.sub(_number_re, _expand_number, text)
+ return text
+
+
+# Convert numbers to Chinese pronunciation
+def number_to_chinese(text):
+ text = cn2an.transform(text, "an2cn")
+ return text
+
+
+def map_punctuations(text):
+ text = text.replace(",", ",")
+ text = text.replace("。", ".")
+ text = text.replace("!", "!")
+ text = text.replace("?", "?")
+ text = text.replace(";", ";")
+ text = text.replace(":", ":")
+ text = text.replace("、", ",")
+ text = text.replace("‘", "'")
+ text = text.replace("“", '"')
+ text = text.replace("”", '"')
+ text = text.replace("’", "'")
+ text = text.replace("⋯", "…")
+ text = text.replace("···", "…")
+ text = text.replace("・・・", "…")
+ text = text.replace("...", "…")
+ return text
+
+
+def is_chinese(char):
+ if char >= "\u4e00" and char <= "\u9fa5":
+ return True
+ else:
+ return False
+
+
+def is_alphabet(char):
+ if (char >= "\u0041" and char <= "\u005a") or (
+ char >= "\u0061" and char <= "\u007a"
+ ):
+ return True
+ else:
+ return False
+
+
+def is_hangul(char):
+ letters = unicodedata.normalize("NFD", char)
+ return all(
+ ["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters]
+ )
+
+
+def is_japanese(char):
+ return any(
+ [
+ start <= char <= end
+ for start, end in [
+ ("\u3041", "\u3096"),
+ ("\u30a0", "\u30ff"),
+ ("\uff5f", "\uff9f"),
+ ("\u31f0", "\u31ff"),
+ ("\u3220", "\u3243"),
+ ("\u3280", "\u337f"),
+ ]
+ ]
+ )
+
+
+def get_segment(text: str) -> List[str]:
+ # sentence --> [ch_part, en_part, ch_part, ...]
+ # example :
+ # input : 我们是小米人,是吗? Yes I think so!霍...啦啦啦
+ # output : [('我们是小米人,是吗? ', 'zh'), ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')]
+ segments = []
+ types = []
+ flag = 0
+ temp_seg = ""
+ temp_lang = ""
+
+ for i, ch in enumerate(text):
+ if is_chinese(ch):
+ types.append("zh")
+ elif is_alphabet(ch):
+ types.append("en")
+ else:
+ types.append("other")
+
+ assert len(types) == len(text)
+
+ for i in range(len(types)):
+ # find the first char of the seg
+ if flag == 0:
+ temp_seg += text[i]
+ temp_lang = types[i]
+ flag = 1
+ else:
+ if temp_lang == "other":
+ if types[i] == temp_lang:
+ temp_seg += text[i]
+ else:
+ temp_seg += text[i]
+ temp_lang = types[i]
+ else:
+ if types[i] == temp_lang:
+ temp_seg += text[i]
+ elif types[i] == "other":
+ temp_seg += text[i]
+ else:
+ segments.append((temp_seg, temp_lang))
+ temp_seg = text[i]
+ temp_lang = types[i]
+ flag = 1
+
+ segments.append((temp_seg, temp_lang))
+ return segments
+
+
+def preprocess(text: str) -> str:
+ text = map_punctuations(text)
+ return text
+
+
+def tokenize_ZH(text: str) -> List[str]:
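+ # Each pinyin syllable is split into an initial (with a trailing "0" so it does not
+ # collide with espeak tokens) and a tone3-style final; e.g. "中国" roughly becomes
+ # ["zh0", "ong1", "g0", "uo2"].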
+ try:
+ text = number_to_chinese(text)
+ segs = list(jieba.cut(text))
+ full = lazy_pinyin(
+ segs, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True
+ )
+ phones = []
+ for x in full:
+ # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
+ if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")):
+ phones.append(x)
+ continue
+ initial = to_initials(x, strict=False)
+ # don't want to share tokens with espeak tokens, so use tone3 style
+ final = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
+ if initial != "":
+ # don't want to share tokens with espeak tokens, so add a '0' after each initial
+ phones.append(initial + "0")
+ if final != "":
+ phones.append(final)
+ return phones
+ except Exception:
+ logging.warning(f"Failed to tokenize Chinese text: {text}")
+ return []
+
+
+def tokenize_EN(text: str) -> List[str]:
+ try:
+ text = expand_abbreviations(text)
+ text = normalize_numbers(text)
+ tokens = phonemize_espeak(text, "en-us")
+ tokens = reduce(lambda x, y: x + y, tokens)
+ return tokens
+ except Exception:
+ logging.warning(f"Failed to tokenize English text: {text}")
+ return []
+
+
+class TokenizerEmilia(object):
+ def __init__(self, token_file: Optional[str] = None, token_type="phone"):
+ """
+ Args:
+ token_file: the file that maps tokens to ids,
+ which is a text file with '{token}\t{token_id}' per line.
+ token_type: the token type; only "phone" is supported.
+ """
+ assert (
+ token_type == "phone"
+ ), f"Only support phone tokenizer for Emilia, but get {token_type}."
+ self.has_tokens = False
+ if token_file is None:
+ logging.debug(
+ "Initialize Tokenizer without tokens file, will fail when map to ids."
+ )
+ return
+ self.token2id: Dict[str, int] = {}
+ with open(token_file, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split("\t")
+ token, id = info[0], int(info[1])
+ assert token not in self.token2id, token
+ self.token2id[token] = id
+ self.pad_id = self.token2id["_"] # padding
+
+ self.vocab_size = len(self.token2id)
+ self.has_tokens = True
+
+ def texts_to_token_ids(
+ self,
+ texts: List[str],
+ ) -> List[List[int]]:
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
+
+ def texts_to_tokens(
+ self,
+ texts: List[str],
+ ) -> List[List[str]]:
+ """
+ Args:
+ texts:
+ A list of transcripts.
+ Returns:
+ Return a list of a list of tokens [utterance][token]
+ """
+ for i in range(len(texts)):
+ # Text normalization
+ texts[i] = preprocess(texts[i])
+
+ phoneme_list = []
+ for text in texts:
+ # now only en and ch
+ segments = get_segment(text)
+ all_phoneme = []
+ for index in range(len(segments)):
+ seg = segments[index]
+ if seg[1] == "zh":
+ phoneme = tokenize_ZH(seg[0])
+ else:
+ if seg[1] != "en":
+ logging.warning(
+ f"The lang should be en, given {seg[1]}, skipping segment : {seg}"
+ )
+ continue
+ phoneme = tokenize_EN(seg[0])
+ all_phoneme += phoneme
+ phoneme_list.append(all_phoneme)
+ return phoneme_list
+
+ def tokens_to_token_ids(
+ self,
+ tokens: List[List[str]],
+ intersperse_blank: bool = False,
+ ) -> List[List[int]]:
+ """
+ Args:
+ tokens:
+ A list of token lists, each corresponding to one utterance.
+ intersperse_blank:
+ Whether to intersperse blanks in the token sequence.
+
+ Returns:
+ Return a list of token id list [utterance][token_id]
+ """
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
+ token_ids = []
+
+ for tks in tokens:
+ ids = []
+ for t in tks:
+ if t not in self.token2id:
+ logging.warning(f"Skip OOV {t}")
+ continue
+ ids.append(self.token2id[t])
+
+ if intersperse_blank:
+ ids = intersperse(ids, self.pad_id)
+
+ token_ids.append(ids)
+
+ return token_ids
+
+
+class TokenizerLibriTTS(object):
+ def __init__(self, token_file: str, token_type: str):
+ """
+ Args:
+ token_type: the type of tokenizer, e.g., bpe, char, phone.
+ token_file: the file that maps tokens to ids,
+ which is a text file with '{token}\t{token_id}' per line if the type is
+ char or phone, otherwise it is a bpe model file.
+ """
+ self.type = token_type
+ assert token_type in ["bpe", "char", "phone"]
+ # Parse token file
+
+ if token_type == "bpe":
+ import sentencepiece as spm
+
+ self.sp = spm.SentencePieceProcessor()
+ self.sp.load(token_file)
+ self.pad_id = self.sp.piece_to_id("")
+ self.vocab_size = self.sp.get_piece_size()
+ else:
+ self.token2id: Dict[str, int] = {}
+ with open(token_file, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split("\t")
+ token, id = info[0], int(info[1])
+ assert token not in self.token2id, token
+ self.token2id[token] = id
+ self.pad_id = self.token2id["_"] # padding
+ self.vocab_size = len(self.token2id)
+ try:
+ from tacotron_cleaner.cleaners import custom_english_cleaners as cleaner
+ except Exception as ex:
+ raise RuntimeError(
+ f"{ex}\nPlease run\n"
+ "pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/"
+ )
+ self.cleaner = cleaner
+
+ def texts_to_token_ids(
+ self,
+ texts: List[str],
+ lang: str = "en-us",
+ ) -> List[List[int]]:
+ """
+ Args:
+ texts:
+ A list of transcripts.
+ lang:
+ Language argument passed to phonemize_espeak().
+
+ Returns:
+ Return a list of token id list [utterance][token_id]
+ """
+ for i in range(len(texts)):
+ # Text normalization
+ texts[i] = self.cleaner(texts[i])
+
+ if self.type == "bpe":
+ token_ids_list = self.sp.encode(texts)
+
+ elif self.type == "phone":
+ token_ids_list = []
+ for text in texts:
+ tokens_list = phonemize_espeak(text.lower(), lang)
+ tokens = []
+ for t in tokens_list:
+ tokens.extend(t)
+ token_ids = []
+ for t in tokens:
+ if t not in self.token2id:
+ logging.warning(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
+
+ token_ids_list.append(token_ids)
+ else:
+ token_ids_list = []
+ for text in texts:
+ token_ids = []
+ for t in text:
+ if t not in self.token2id:
+ logging.warning(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
+
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+
+ def tokens_to_token_ids(
+ self,
+ tokens_list: List[List[str]],
+ ) -> List[List[int]]:
+ """
+ Args:
+ tokens_list:
+ A list of token lists, each corresponding to one utterance.
+
+ Returns:
+ Return a list of token id list [utterance][token_id]
+ """
+ token_ids_list = []
+
+ for tokens in tokens_list:
+ token_ids = []
+ for t in tokens:
+ if t not in self.token2id:
+ logging.warning(f"Skip OOV {t}")
+ continue
+ token_ids.append(self.token2id[t])
+
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+
+
+if __name__ == "__main__":
+ text = "我们是5年小米人,是吗? Yes I think so! mr king, 5 years, from 2019 to 2024. 霍...啦啦啦超过90%的人咯...?!9204"
+ tokenizer = TokenizerEmilia()
+ tokens = tokenizer.texts_to_tokens([text])
+ print(f"tokens : {tokens}")
+ tokens2 = "|".join(tokens[0])
+ print(f"tokens2 : {tokens2}")
+ tokens2 = tokens2.split("|")
+ assert tokens[0] == tokens2
diff --git a/egs/zipvoice/zipvoice/train_distill.py b/egs/zipvoice/zipvoice/train_distill.py
new file mode 100644
index 000000000..ae784050b
--- /dev/null
+++ b/egs/zipvoice/zipvoice/train_distill.py
@@ -0,0 +1,1043 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script trains a ZipVoice-Distill model starting from a ZipVoice model.
+It has two distillation stages.
+
+Usage:
+
+(1) The first distillation stage with a fixed ZipVoice model as the teacher.
+python3 zipvoice/train_distill.py \
+ --world-size 8 \
+ --use-fp16 1 \
+ --tensorboard 1 \
+ --dataset "emilia" \
+ --base-lr 0.0005 \
+ --max-duration 500 \
+ --token-file "data/tokens_emilia.txt" \
+ --manifest-dir "data/fbank_emilia" \
+ --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
+ --num-updates 60000 \
+ --distill-stage "first" \
+ --exp-dir zipvoice/exp_zipvoice_distill_1stage
+
+(2) The second distillation stage with an EMA model as the teacher.
+python3 zipvoice/train_distill.py \
+ --world-size 8 \
+ --use-fp16 1 \
+ --tensorboard 1 \
+ --dataset "emilia" \
+ --base-lr 0.0001 \
+ --max-duration 500 \
+ --token-file "data/tokens_emilia.txt" \
+ --manifest-dir "data/fbank_emilia" \
+ --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \
+ --num-updates 2000 \
+ --distill-stage "second" \
+ --exp-dir zipvoice/exp_zipvoice_distill
+"""
+
+import argparse
+import copy
+import logging
+import os
+import random
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from checkpoint import load_checkpoint, save_checkpoint
+from lhotse.cut import Cut, CutSet
+from lhotse.utils import fix_random_seed
+from model import get_distill_model, get_model
+from optim import FixedLRScheduler, ScaledAdam
+from tokenizer import TokenizerEmilia, TokenizerLibriTTS
+from torch import Tensor
+from torch.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+from train_flow import add_model_arguments, get_params
+from tts_datamodule import TtsDataModule
+from utils import (
+ condition_time_mask,
+ get_adjusted_batch_count,
+ prepare_input,
+ set_batch_count,
+)
+
+from icefall import diagnostics
+from icefall.checkpoint import (
+ remove_checkpoints,
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ make_pad_mask,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=1,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--num-updates",
+ type=int,
+ default=0,
+ help="Number of updates to train, will ignore num_epochs if > 0.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--teacher-model",
+ type=str,
+ help="""Checkpoints of pre-trained teacher model""",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipvoice/exp_zipvoice_distill",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.001, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=50,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--ema-decay",
+ type=float,
+ default=0.9999,
+ help="The EMA decay factor of target model in distillation.",
+ )
+ parser.add_argument(
+ "--distill-stage",
+ type=str,
+ choices=["first", "second"],
+ help="The stage of distillation.",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="emilia",
+ choices=["emilia", "libritts"],
+ help="The used training dataset",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def ema(new_model, ema_model, decay):
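+ # In-place exponential moving average update: ema <- decay * ema + (1 - decay) * new.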
+ if isinstance(new_model, DDP):
+ new_model = new_model.module
+ if isinstance(ema_model, DDP):
+ ema_model = ema_model.module
+ new_model_dict = new_model.state_dict()
+ ema_model_dict = ema_model.state_dict()
+ for key in new_model_dict.keys():
+ ema_model_dict[key].data.copy_(
+ ema_model_dict[key].data * decay + new_model_dict[key].data * (1 - decay)
+ )
+
+
+def resume_checkpoint(
+ params: AttributeDict, model: nn.Module, model_avg: nn.Module, model_ema: Optional[nn.Module] = None
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename, model=model, model_avg=model_avg, model_ema=model_ema, strict=True
+ )
+
+ if params.start_epoch > 1:
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def compute_fbank_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ teacher_model: Union[nn.Module, DDP],
+ features: Tensor,
+ features_lens: Tensor,
+ tokens: List[List[int]],
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training.
+ teacher_model:
+ The teacher model for distillation.
+ features:
+ The target acoustic feature.
+ features_lens:
+ The number of frames of each utterance.
+ tokens:
+ Input tokens representing the transcripts.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batch_size, num_frames, _ = features.shape
+
+ features = torch.nn.functional.pad(
+ features, (0, 0, 0, num_frames - features.size(1))
+ ) # (B, T, F)
+ noise = torch.randn_like(features) # (B, T, F)
+
+ # Sampling t and guidance_scale from uniform distribution
+
+ t_value = random.random()
+ t = torch.ones(batch_size, 1, 1, device=device) * t_value
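+ # The guidance scale is sampled from [0, 2) in the first stage and [1, 3) in the second.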
+ if params.distill_stage == "first":
+ guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2
+ else:
+ guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 + 1
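+ # Interpolate between noise (t = 0) and the clean features (t = 1).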
+ xt = features * t + noise * (1 - t)
+ t_delta_fix = random.uniform(0.0, min(0.3, 1 - t_value))
+ t_delta_ema = random.uniform(0.0, min(0.3, 1 - t_value - t_delta_fix))
+ t_dest = t + t_delta_fix + t_delta_ema
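+ # The teacher integrates t -> t + t_delta_fix -> t_dest in two one-step jumps below,
+ # while the student is trained to cover t -> t_dest in a single step.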
+
+ with torch.no_grad():
+ speech_condition_mask = condition_time_mask(
+ features_lens=features_lens,
+ mask_percent=(0.7, 1.0),
+ max_len=features.size(1),
+ )
+
+ if params.distill_stage == "first":
+ teacher_x_t_mid, _ = teacher_model.sample_intermediate(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=xt,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t,
+ t_end=t + t_delta_fix,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
+
+ target_x1, _ = teacher_model.sample_intermediate(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=teacher_x_t_mid,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t + t_delta_fix,
+ t_end=t_dest,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
+ else:
+ teacher_x_t_mid, _ = teacher_model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=xt,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t,
+ t_end=t + t_delta_fix,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
+
+ target_x1, _ = teacher_model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=teacher_x_t_mid,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t + t_delta_fix,
+ t_end=t_dest,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
+
+ with torch.set_grad_enabled(is_training):
+
+ pred_x1, _ = model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=xt,
+ speech_condition_mask=speech_condition_mask,
+ t_start=t,
+ t_end=t_dest,
+ num_step=1,
+ guidance_scale=guidance_scale,
+ )
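+ # Convert the predicted endpoint at t_dest back into a velocity over [t, t_dest],
+ # to be matched against the teacher-derived target velocity below.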
+ pred_v = (pred_x1 - xt) / (t_dest - t)
+
+ padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
+ loss_mask = speech_condition_mask & (~padding_mask)
+
+ target_v = (target_x1 - xt) / (t_dest - t)
+ loss = torch.mean((pred_v[loss_mask] - target_v[loss_mask]) ** 2)
+
+ ut = features - noise # (B, T, F)
+
+ ref_loss = torch.mean((pred_v[loss_mask] - ut[loss_mask]) ** 2)
+
+ assert loss.requires_grad == is_training
+ info = MetricsTracker()
+ num_frames = features_lens.sum().item()
+ info["frames"] = num_frames
+ info["loss"] = loss.detach().cpu().item() * num_frames
+ info["ref_loss"] = ref_loss.detach().cpu().item() * num_frames
+ return loss, info
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ teacher_model: Union[nn.Module, DDP],
+ tokenizer: TokenizerEmilia,
+ optimizer: Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ teacher_model:
+ The model for distillation.
+ tokenizer:
+ Used to convert text to tokens.
+ optimizer:
+ The optimizer.
+ scheduler:
+ The learning rate scheduler; we call step() every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to track the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ model_ema=teacher_model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
+
+ if (
+ params.valid_interval is None
+ and batch_idx == 0
+ and not params.print_diagnostics
+ ) or (
+ params.valid_interval is not None
+ and params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ teacher_model=teacher_model,
+ tokenizer=tokenizer,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["text"])
+
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ tokenizer=tokenizer,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ try:
+ with autocast("cuda", enabled=params.use_fp16):
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ teacher_model=teacher_model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ )
+
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ scaler.scale(loss).backward()
+
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ if params.distill_stage == "second":
+ ema(model, teacher_model, params.ema_decay)
+ except RuntimeError as e:
+ if "out of memory" in str(e):
+ logging.info(f"out of memory error at rank {rank}")
+ # optimizer.zero_grad()
+ # duration_optimizer.zero_grad()
+ torch.cuda.empty_cache()
+ raise
+ else:
+ logging.info(f"Caught exception : {e}.")
+ save_bad_model()
+ raise
+ except Exception as e:
+ logging.info(f"Caught exception : {e}.")
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+ if (
+ params.batch_idx_train > 0
+ and params.num_updates > 0
+ and params.batch_idx_train > params.num_updates
+ ):
+ break
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 1024.0 or (
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
+ ):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ teacher_model: Optional[nn.Module],
+ tokenizer: TokenizerEmilia,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summary the stats over iterations
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ tokenizer=tokenizer,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ teacher_model=teacher_model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ os.makedirs(f"{params.exp_dir}/fbank", exist_ok=True)
+
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ if params.dataset == "emilia":
+ tokenizer = TokenizerEmilia(
+ token_file=params.token_file, token_type=params.token_type
+ )
+ elif params.dataset == "libritts":
+ tokenizer = TokenizerLibriTTS(
+ token_file=params.token_file, token_type=params.token_type
+ )
+
+ params.vocab_size = tokenizer.vocab_size
+ params.pad_id = tokenizer.pad_id
+
+ params.device = device
+
+ logging.info(params)
+
+ logging.info("About to create model")
+
+ assert params.teacher_model is not None
+ logging.info(f"Loading pre-trained model from {params.teacher_model}")
+ model = get_distill_model(params)
+ _ = load_checkpoint(
+ filename=params.teacher_model,
+ model=model,
+ strict=(params.distill_stage == "second"),
+ )
+
+ if params.distill_stage == "first":
+ teacher_model = get_model(params)
+ _ = load_checkpoint(
+ filename=params.teacher_model, model=teacher_model, strict=True
+ )
+ else:
+ teacher_model = copy.deepcopy(model)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of parameters : {num_param}")
+
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+ assert params.start_epoch > 0, params.start_epoch
+ if params.start_epoch > 1:
+ logging.info(f"Resuming from epoch {params.start_epoch}")
+ if params.distill_stage == "first":
+ checkpoints = resume_checkpoint(
+ params=params, model=model, model_avg=model_avg
+ )
+ else:
+ checkpoints = resume_checkpoint(
+ params=params, model=model, model_avg=model_avg, model_ema=teacher_model
+ )
+
+ model = model.to(device)
+ teacher_model.to(device)
+ teacher_model.eval()
+
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ # only update the fm_decoder
+ num_trainable = 0
+ for name, p in model.named_parameters():
+ if "fm_decoder" in name:
+ p.requires_grad = True
+ num_trainable += p.numel()
+ else:
+ p.requires_grad = False
+
+ logging.info(
+ "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
+ num_trainable, num_trainable / num_param * 100
+ )
+ )
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(
+ model,
+ lr=params.base_lr,
+ include_names=True,
+ ),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ scheduler = FixedLRScheduler(optimizer)
+
+ scaler = GradScaler("cuda", enabled=params.use_fp16)
+
+ if params.start_epoch > 1 and checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ # load state_dict for schedulers
+ if "scheduler" in checkpoints:
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt_emilia(c: Cut):
+ if c.duration < 1.0 or c.duration > 30.0:
+ return False
+ return True
+
+ def remove_short_and_long_utt_libritts(c: Cut):
+ if c.duration < 1.0 or c.duration > 20.0:
+ return False
+ return True
+
+ datamodule = TtsDataModule(args)
+ if params.dataset == "emilia":
+ train_cuts = CutSet.mux(
+ datamodule.train_emilia_EN_cuts(),
+ datamodule.train_emilia_ZH_cuts(),
+ weights=[46000, 49000],
+ )
+ train_cuts = train_cuts.filter(remove_short_and_long_utt_emilia)
+ dev_cuts = CutSet.mux(
+ datamodule.dev_emilia_EN_cuts(),
+ datamodule.dev_emilia_ZH_cuts(),
+ weights=[0.5, 0.5],
+ )
+ elif params.dataset == "libritts":
+ train_cuts = datamodule.train_libritts_cuts()
+ train_cuts = train_cuts.filter(remove_short_and_long_utt_libritts)
+ dev_cuts = datamodule.dev_libritts_cuts()
+
+ train_dl = datamodule.train_dataloaders(train_cuts)
+
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ teacher_model=teacher_model,
+ tokenizer=tokenizer,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ model_ema=teacher_model,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ TtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ main()
diff --git a/egs/zipvoice/zipvoice/train_flow.py b/egs/zipvoice/zipvoice/train_flow.py
new file mode 100644
index 000000000..74d81b726
--- /dev/null
+++ b/egs/zipvoice/zipvoice/train_flow.py
@@ -0,0 +1,1108 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Wei Kang,
+# Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script trains a ZipVoice model with the flow-matching loss.
+
+Usage:
+
+python3 zipvoice/train_flow.py \
+ --world-size 8 \
+ --use-fp16 1 \
+ --dataset emilia \
+ --max-duration 500 \
+ --lr-hours 30000 \
+ --lr-batches 7500 \
+ --token-file "data/tokens_emilia.txt" \
+ --manifest-dir "data/fbank_emilia" \
+ --num-epochs 11 \
+ --exp-dir zipvoice/exp_zipvoice
+"""
+
+import argparse
+import copy
+import logging
+import os
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from checkpoint import load_checkpoint, save_checkpoint
+from lhotse.cut import Cut, CutSet
+from lhotse.utils import fix_random_seed
+from model import get_model
+from optim import Eden, ScaledAdam
+from tokenizer import TokenizerEmilia, TokenizerLibriTTS
+from torch import Tensor
+from torch.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+from tts_datamodule import TtsDataModule
+from utils import get_adjusted_batch_count, prepare_input, set_batch_count
+
+from icefall import diagnostics
+from icefall.checkpoint import (
+ remove_checkpoints,
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ get_parameter_groups_with_lrs,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--fm-decoder-downsampling-factor",
+ type=str,
+ default="1,2,4,2,1",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-num-layers",
+ type=str,
+ default="2,2,4,4,4",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-cnn-module-kernel",
+ type=str,
+ default="31,15,7,15,31",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-feedforward-dim",
+ type=int,
+ default=1536,
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-num-heads",
+ type=int,
+ default=4,
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-downsampling-factor",
+ type=str,
+ default="1",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-num-layers",
+ type=str,
+ default="4",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-feedforward-dim",
+ type=int,
+ default=512,
+ help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-cnn-module-kernel",
+ type=str,
+ default="9",
+ help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-num-heads",
+ type=int,
+ default=4,
+ help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-dim",
+ type=int,
+ default=192,
+ help="Embedding dimension in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=int,
+ default=32,
+ help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=int,
+ default=12,
+ help="Value dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=int,
+ default=4,
+ help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default=48,
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--time-embed-dim",
+ type=int,
+ default=192,
+ help="Embedding dimension of timestamps embedding.",
+ )
+
+ parser.add_argument(
+ "--text-embed-dim",
+ type=int,
+ default=192,
+ help="Embedding dimension of text embedding.",
+ )
+
+ parser.add_argument(
+ "--token-type",
+ type=str,
+ default="phone",
+ choices=["phone", "char", "bpe"],
+ help="Input token type of TTS model, by default, "
+ "we use phone for emilia, char for libritts.",
+ )
+
+ parser.add_argument(
+ "--token-file",
+ type=str,
+ default="data/tokens_emilia.txt",
+ help="The file that contains information that maps tokens to ids,"
+ "which is a text file with '{token}\t{token_id}' per line if type is"
+ "char or phone, otherwise it is a bpe_model file.",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=11,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ default=None,
+ help="""Checkpoints of pre-trained models, will load it if not None
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="zipvoice/exp_zipvoice",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.02, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=7500,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=10,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--lr-hours",
+ type=float,
+ default=0,
+ help="""If positive, --epoch is ignored and it specifies the number of hours
+ that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--ref-duration",
+ type=float,
+ default=50,
+ help="Reference batch duration for purposes of adjusting batch counts for setting various "
+ "schedules inside the model",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=4000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--condition-drop-ratio",
+ type=float,
+ default=0.2,
+ help="The drop rate of text condition during training.",
+ )
+
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="emilia",
+ choices=["emilia", "libritts"],
+ help="The used training dataset",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - sampling_rate: Sampling rate of the waveform.
+
+ - frame_shift_ms: Frame shift in milliseconds.
+
+ - feat_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - env_info: A dict containing information about the environment.
+
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 4000,
+ "sampling_rate": 24000,
+ "frame_shift_ms": 256 / 24000 * 1000,
+ "feat_dim": 100,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def resume_checkpoint(
+ params: AttributeDict, model: nn.Module, model_avg: nn.Module
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ logging.info(f"Resuming from file {filename}")
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename, model=model, model_avg=model_avg, strict=True
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def compute_fbank_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ features: Tensor,
+ features_lens: Tensor,
+ tokens: List[List[int]],
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training.
+ features:
+ The target acoustic feature.
+ features_lens:
+ The number of frames of each utterance.
+ tokens:
+ Input tokens representing the transcripts.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ batch_size, num_frames, _ = features.shape
+
+ features = torch.nn.functional.pad(
+ features, (0, 0, 0, num_frames - features.size(1))
+ ) # (B, T, F)
+ noise = torch.randn_like(features) # (B, T, F)
+
+ # Sampling t from uniform distribution
+ if is_training:
+ t = torch.rand(batch_size, 1, 1, device=device)
+ else:
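+ # For validation, use deterministic, evenly spaced t values across the batch
+ # so that validation losses are comparable across runs.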
+ t = (
+ (torch.arange(batch_size, device=device) / batch_size)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ )
+ with torch.set_grad_enabled(is_training):
+
+ loss = model(
+ tokens=tokens,
+ features=features,
+ features_lens=features_lens,
+ noise=noise,
+ t=t,
+ condition_drop_ratio=params.condition_drop_ratio,
+ )
+
+ assert loss.requires_grad == is_training
+ info = MetricsTracker()
+ num_frames = features_lens.sum().item()
+ info["frames"] = num_frames
+ info["loss"] = loss.detach().cpu().item() * num_frames
+
+ return loss, info
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: TokenizerEmilia,
+ optimizer: Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ tokenizer:
+ Used to convert text to tokens.
+ optimizer:
+ The optimizer.
+ scheduler:
+ The learning rate scheduler; we call step() every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to track the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+
+ if batch_idx % 10 == 0:
+ set_batch_count(model, get_adjusted_batch_count(params))
+
+ if (
+ params.valid_interval is None
+ and batch_idx == 0
+ and not params.print_diagnostics
+ ) or (
+ params.valid_interval is not None
+ and params.batch_idx_train % params.valid_interval == 0
+ and not params.print_diagnostics
+ ):
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["text"])
+
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ tokenizer=tokenizer,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ try:
+ with autocast("cuda", enabled=params.use_fp16):
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=True,
+ )
+
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ scaler.scale(loss).backward()
+
+ scheduler.step_batch(params.batch_idx_train)
+ # Use the number of hours of speech to adjust the learning rate
+ if params.lr_hours > 0:
+ scheduler.step_epoch(
+ params.batch_idx_train
+ * params.max_duration
+ * params.world_size
+ / 3600
+ )
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except RuntimeError as e:
+ if "out of memory" in str(e):
+ logging.info(f"out of memory error at rank {rank}")
+ torch.cuda.empty_cache()
+ raise
+ else:
+ logging.info(f"Caught exception : {e}.")
+ save_bad_model()
+ raise
+ except Exception as e:
+ logging.info(f"Caught exception : {e}.")
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 1024.0 or (
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
+ ):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr = max(scheduler.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: TokenizerEmilia,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summarize the stats over iterations
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ tokens, features, features_lens = prepare_input(
+ params=params,
+ batch=batch,
+ device=device,
+ tokenizer=tokenizer,
+ return_tokens=True,
+ return_feature=True,
+ )
+
+ loss, loss_info = compute_fbank_loss(
+ params=params,
+ model=model,
+ features=features,
+ features_lens=features_lens,
+ tokens=tokens,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ os.makedirs(f"{params.exp_dir}/fbank", exist_ok=True)
+
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ if params.dataset == "emilia":
+ tokenizer = TokenizerEmilia(
+ token_file=params.token_file, token_type=params.token_type
+ )
+ elif params.dataset == "libritts":
+ tokenizer = TokenizerLibriTTS(
+ token_file=params.token_file, token_type=params.token_type
+ )
+ params.vocab_size = tokenizer.vocab_size
+ params.pad_id = tokenizer.pad_id
+
+ params.device = device
+
+ logging.info(params)
+
+ logging.info("About to create model")
+
+ model = get_model(params)
+ if params.checkpoint is not None:
+ logging.info(f"Loading pre-trained model from {params.checkpoint}")
+ _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of parameters : {num_param}")
+
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+ assert params.start_epoch > 0, params.start_epoch
+ if params.start_epoch > 1:
+ checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
+
+ model = model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(
+ get_parameter_groups_with_lrs(
+ model,
+ lr=params.base_lr,
+ include_names=True,
+ ),
+ lr=params.base_lr, # should have no effect
+ clipping_scale=2.0,
+ )
+
+ assert params.lr_hours >= 0
+ if params.lr_hours > 0:
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
+ else:
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ scaler = GradScaler("cuda", enabled=params.use_fp16)
+
+ if params.start_epoch > 1 and checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ # load state_dict for schedulers
+ if "scheduler" in checkpoints:
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ def remove_short_and_long_utt_emilia(c: Cut):
+ if c.duration < 1.0 or c.duration > 30.0:
+ return False
+ return True
+
+ def remove_short_and_long_utt_libritts(c: Cut):
+ if c.duration < 1.0 or c.duration > 20.0:
+ return False
+ return True
+
+ datamodule = TtsDataModule(args)
+ if params.dataset == "emilia":
+ train_cuts = CutSet.mux(
+ datamodule.train_emilia_EN_cuts(),
+ datamodule.train_emilia_ZH_cuts(),
+ weights=[46000, 49000],
+ )
+ train_cuts = train_cuts.filter(remove_short_and_long_utt_emilia)
+ dev_cuts = CutSet.mux(
+ datamodule.dev_emilia_EN_cuts(),
+ datamodule.dev_emilia_ZH_cuts(),
+ weights=[0.5, 0.5],
+ )
+ elif params.dataset == "libritts":
+ train_cuts = datamodule.train_libritts_cuts()
+ train_cuts = train_cuts.filter(remove_short_and_long_utt_libritts)
+ dev_cuts = datamodule.dev_libritts_cuts()
+
+ train_dl = datamodule.train_dataloaders(train_cuts)
+
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+
+ if params.lr_hours == 0:
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ tokenizer=tokenizer,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ TtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ main()
diff --git a/egs/zipvoice/zipvoice/tts_datamodule.py b/egs/zipvoice/zipvoice/tts_datamodule.py
new file mode 100644
index 000000000..e8ea7a4eb
--- /dev/null
+++ b/egs/zipvoice/zipvoice/tts_datamodule.py
@@ -0,0 +1,456 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Zengwei Yao,
+# Zengrui Jin,
+# Han Zhu,
+# Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+
+import torch
+from feature import TorchAudioFbank, TorchAudioFbankConfig
+from lhotse import CutSet, load_manifest_lazy, validate
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ DynamicBucketingSampler,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+)
+from lhotse.dataset.collation import collate_audio
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ BatchIO,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed, ifnone
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+SAMPLING_RATE = 24000
+
+
+class TtsDataModule:
+ """
+ DataModule for tts experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in TTS
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - cut concatenation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in TTS tasks.
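+
+ A minimal usage sketch (illustrative; the manifests under
+ ``--manifest-dir`` must already exist):
+
+ >>> parser = argparse.ArgumentParser()
+ >>> TtsDataModule.add_arguments(parser)
+ >>> args = parser.parse_args([])
+ >>> datamodule = TtsDataModule(args)
+ >>> train_dl = datamodule.train_dataloaders(datamodule.train_emilia_EN_cuts())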
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="TTS data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/fbank_emilia"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=100,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=8,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+ train = SpeechSynthesisDataset(
+ return_text=True,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ sampling_rate = SAMPLING_RATE
+ config = TorchAudioFbankConfig(
+ sampling_rate=sampling_rate,
+ n_mels=100,
+ n_fft=1024,
+ hop_length=256,
+ )
+ train = SpeechSynthesisDataset(
+ return_text=True,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ buffer_size=self.args.num_buckets * 2000,
+ shuffle_buffer_size=self.args.num_buckets * 5000,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = SAMPLING_RATE
+ config = TorchAudioFbankConfig(
+ sampling_rate=sampling_rate,
+ n_mels=100,
+ n_fft=1024,
+ hop_length=256,
+ )
+ validate = SpeechSynthesisDataset(
+ return_text=True,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = SpeechSynthesisDataset(
+ return_text=True,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ dev_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create valid dataloader")
+ dev_dl = DataLoader(
+ validate,
+ sampler=dev_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return dev_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.info("About to create test dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = SAMPLING_RATE
+ config = TorchAudioFbankConfig(
+ sampling_rate=sampling_rate,
+ n_mels=100,
+ n_fft=1024,
+ hop_length=256,
+ )
+ test = SpeechSynthesisDataset(
+ return_text=True,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
+ return_cuts=self.args.return_cuts,
+ return_audio=True,
+ )
+ else:
+ test = SpeechSynthesisDataset(
+ return_text=True,
+ return_tokens=True,
+ return_spk_ids=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ return_audio=True,
+ )
+ test_sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=test_sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_emilia_EN_cuts(self) -> CutSet:
+ logging.info("About to get train the EN subset")
+ return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz")
+
+ @lru_cache()
+ def train_emilia_ZH_cuts(self) -> CutSet:
+ logging.info("About to get train the ZH subset")
+ return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz")
+
+ @lru_cache()
+ def dev_emilia_EN_cuts(self) -> CutSet:
+ logging.info("About to get dev the EN subset")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_emilia_ZH_cuts(self) -> CutSet:
+ logging.info("About to get dev the ZH subset")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz"
+ )
+
+ @lru_cache()
+ def train_libritts_cuts(self) -> CutSet:
+ logging.info(
+ "About to get the shuffled train-clean-100, "
+ "train-clean-360 and train-other-500 cuts"
+ )
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz"
+ )
+
+ @lru_cache()
+ def dev_libritts_cuts(self) -> CutSet:
+ logging.info("About to get dev-clean cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz"
+ )
+
+
+class SpeechSynthesisDataset(torch.utils.data.Dataset):
+ """
+ The PyTorch Dataset for the speech synthesis task.
+ Each item in this dataset is a dict of:
+
+ .. code-block::
+
+ {
+ 'audio': (B x NumSamples) float tensor # when return_audio=True
+ 'features': (B x NumFrames x NumFeatures) float tensor
+ 'audio_lens': (B, ) int tensor # when return_audio=True
+ 'features_lens': (B, ) int tensor
+ 'text': List[str] of len B # when return_text=True
+ 'tokens': List[List[str]] # when return_tokens=True
+ 'speakers': List[str] of len B # when return_spk_ids=True
+ 'cut': List of Cuts # when return_cuts=True
+ }
+ """
+
+ def __init__(
+ self,
+ cut_transforms: List[Callable[[CutSet], CutSet]] = None,
+ feature_input_strategy: BatchIO = PrecomputedFeatures(),
+ feature_transforms: Union[Sequence[Callable], Callable] = None,
+ return_text: bool = True,
+ return_tokens: bool = False,
+ return_spk_ids: bool = False,
+ return_cuts: bool = False,
+ return_audio: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.cut_transforms = ifnone(cut_transforms, [])
+ self.feature_input_strategy = feature_input_strategy
+
+ self.return_text = return_text
+ self.return_tokens = return_tokens
+ self.return_spk_ids = return_spk_ids
+ self.return_cuts = return_cuts
+ self.return_audio = return_audio
+
+ if feature_transforms is None:
+ feature_transforms = []
+ elif not isinstance(feature_transforms, Sequence):
+ feature_transforms = [feature_transforms]
+
+ assert all(
+ isinstance(transform, Callable) for transform in feature_transforms
+ ), "Feature transforms must be Callable"
+ self.feature_transforms = feature_transforms
+
+ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
+ validate_for_tts(cuts)
+
+ for transform in self.cut_transforms:
+ cuts = transform(cuts)
+
+ features, features_lens = self.feature_input_strategy(cuts)
+
+ for transform in self.feature_transforms:
+ features = transform(features)
+
+ batch = {
+ "features": features,
+ "features_lens": features_lens,
+ }
+
+ if self.return_audio:
+ audio, audio_lens = collate_audio(cuts)
+ batch["audio"] = audio
+ batch["audio_lens"] = audio_lens
+
+ if self.return_text:
+ # use normalized text
+ text = [cut.supervisions[0].normalized_text for cut in cuts]
+ batch["text"] = text
+
+ if self.return_tokens:
+ tokens = [cut.tokens for cut in cuts]
+ batch["tokens"] = tokens
+
+ if self.return_spk_ids:
+ batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
+
+ if self.return_cuts:
+ batch["cut"] = [cut for cut in cuts]
+
+ return batch
+
+
+def validate_for_tts(cuts: CutSet) -> None:
+ validate(cuts)
+ for cut in cuts:
+ assert (
+ len(cut.supervisions) == 1
+ ), "Only the Cuts with single supervision are supported."
diff --git a/egs/zipvoice/zipvoice/utils.py b/egs/zipvoice/zipvoice/utils.py
new file mode 100644
index 000000000..4092d0ae4
--- /dev/null
+++ b/egs/zipvoice/zipvoice/utils.py
@@ -0,0 +1,219 @@
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+class AttributeDict(dict):
+ def __getattr__(self, key):
+ if key in self:
+ return self[key]
+ raise AttributeError(f"No such attribute '{key}'")
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def __delattr__(self, key):
+ if key in self:
+ del self[key]
+ return
+ raise AttributeError(f"No such attribute '{key}'")
+
+
+def prepare_input(
+ params: AttributeDict,
+ batch: dict,
+ device: torch.device,
+ tokenizer: Optional[Any] = None,
+ return_tokens: bool = False,
+ return_feature: bool = False,
+ return_audio: bool = False,
+ return_prompt: bool = False,
+):
+ """
+ Parse the features and targets of the current batch.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ tokenizer:
+ Used to convert text or phone tokens to token ids.
+ device:
+ The device of Tensor.
+ Returns:
+ A list whose elements depend on the return_* flags, e.g.
+ [tokens, features, features_lens] when return_tokens and
+ return_feature are True.
+ """
+ return_list = []
+
+ if return_tokens:
+ assert tokenizer is not None
+
+ if params.token_type == "phone":
+ tokens = tokenizer.tokens_to_token_ids(batch["tokens"])
+ else:
+ tokens = tokenizer.texts_to_token_ids(batch["text"])
+ return_list += [tokens]
+
+ if return_feature:
+ features = batch["features"].to(device)
+ features_lens = batch["features_lens"].to(device)
+ return_list += [features * params.feat_scale, features_lens]
+
+ if return_audio:
+ return_list += [batch["audio"], batch["audio_lens"]]
+
+ if return_prompt:
+ if return_tokens:
+ if params.token_type == "phone":
+ prompt_tokens = tokenizer.tokens_to_token_ids(batch["prompt"]["tokens"])
+ else:
+ prompt_tokens = tokenizer.texts_to_token_ids(batch["prompt"]["text"])
+ return_list += [prompt_tokens]
+ if return_feature:
+ prompt_features = batch["prompt"]["features"].to(device)
+ prompt_features_lens = batch["prompt"]["features_lens"].to(device)
+ return_list += [prompt_features * params.feat_scale, prompt_features_lens]
+ if return_audio:
+ return_list += [batch["prompt"]["audio"], batch["prompt"]["audio_lens"]]
+
+ return return_list
+
+
+def prepare_avg_tokens_durations(features_lens, tokens_lens):
+ tokens_durations = []
+ for i in range(len(features_lens)):
+ utt_duration = features_lens[i]
+ avg_token_duration = utt_duration // tokens_lens[i]
+ tokens_durations.append([avg_token_duration] * tokens_lens[i])
+ return tokens_durations
+
+
+def pad_labels(y: List[List[int]], pad_id: int, device: torch.device):
+ """
+ Pad the transcripts to the same length with pad_id (one extra pad_id is
+ appended to each transcript first).
+
+ Args:
+ y: the transcripts, a list of token-id lists
+ pad_id: the id used for padding
+ device: the device of the returned tensor
+
+ Returns:
+ Return a Tensor of padded transcripts.
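+
+ Example (illustrative values):
+
+ >>> pad_labels([[5, 6], [7]], pad_id=0, device=torch.device("cpu"))
+ tensor([[5, 6, 0],
+ [7, 0, 0]])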
+ """
+ y = [l + [pad_id] for l in y]
+ length = max([len(l) for l in y])
+ y = [l + [pad_id] * (length - len(l)) for l in y]
+ return torch.tensor(y, dtype=torch.int64, device=device)
+
+
+def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor:
+ """
+ Gets position in the transcript for each frame, i.e. the position
+ in the symbol-sequence to look up.
+
+ Args:
+ durations:
+ Duration of each token in transcripts.
+ num_frames:
+ The maximum frame length of the current batch.
+
+ Returns:
+ Return a Tensor of shape (batch_size, num_frames)
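+
+ Example (illustrative): with durations=[[2, 3], [4]] and num_frames=6,
+ the result is tensor([[0, 0, 1, 1, 1, 2], [0, 0, 0, 0, 1, 1]]); the
+ remaining frames of each utterance map to an appended final token.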
+ """
+ durations = [x + [num_frames - sum(x)] for x in durations]
+ batch_size = len(durations)
+ ans = torch.zeros(batch_size, num_frames, dtype=torch.int64)
+ for b in range(batch_size):
+ this_dur = durations[b]
+ cur_frame = 0
+ for i, d in enumerate(this_dur):
+ ans[b, cur_frame : cur_frame + d] = i
+ cur_frame += d
+ assert cur_frame == num_frames, (cur_frame, num_frames)
+ return ans
+
+
+def to_int_tuple(s: str):
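+ # e.g. to_int_tuple("2,4,8") -> (2, 4, 8)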
+ return tuple(map(int, s.split(",")))
+
+
+def get_adjusted_batch_count(params: AttributeDict) -> float:
+ # returns the number of batches we would have used so far if we had used the reference
+ # duration. This is for purposes of set_batch_count().
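+ # Illustrative (hypothetical values): after 1000 batches with
+ # max_duration=500, world_size=8 and ref_duration=600, this returns
+ # 1000 * 500 * 8 / 600 ~= 6667.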
+ return (
+ params.batch_idx_train
+ * (params.max_duration * params.world_size)
+ / params.ref_duration
+ )
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for name, module in model.named_modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+ if hasattr(module, "name"):
+ module.name = name
+
+
+def condition_time_mask(
+ features_lens: torch.Tensor, mask_percent: Tuple[float, float], max_len: int = 0
+) -> torch.Tensor:
+ """
+ Apply Time masking.
+ Args:
+ features_lens:
+ input tensor of shape ``(B)``
+ mask_percent:
+ a (min, max) range; the fraction of each utterance to mask is
+ sampled uniformly from this range.
+ max_len:
+ the minimum number of columns of the returned mask; the actual
+ width is max(max_len, features_lens.max()).
+ Returns:
+ Return a 2-D bool tensor (B, T), where masked positions
+ are filled with `True` and non-masked positions are
+ filled with `False`.
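+ For example (illustrative), with mask_percent=(0.3, 0.5) a single
+ contiguous span covering between 30% and 50% of each utterance's
+ frames is masked, starting at a random position.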
+ """
+ mask_size = (
+ torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
+ * features_lens
+ ).to(torch.int64)
+ mask_starts = (
+ torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
+ ).to(torch.int64)
+ mask_ends = mask_starts + mask_size
+ max_len = max(max_len, features_lens.max())
+ seq_range = torch.arange(0, max_len, device=features_lens.device)
+ mask = (seq_range[None, :] >= mask_starts[:, None]) & (
+ seq_range[None, :] < mask_ends[:, None]
+ )
+ return mask
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+ """
+ Args:
+ lengths:
+ A 1-D tensor containing sentence lengths.
+ max_len:
+ The length of masks.
+ Returns:
+ Return a 2-D bool tensor, where masked positions
+ are filled with `True` and non-masked positions are
+ filled with `False`.
+
+ >>> lengths = torch.tensor([1, 3, 2, 5])
+ >>> make_pad_mask(lengths)
+ tensor([[False, True, True, True, True],
+ [False, False, False, True, True],
+ [False, False, True, True, True],
+ [False, False, False, False, False]])
+ """
+ assert lengths.ndim == 1, lengths.ndim
+ max_len = max(max_len, lengths.max())
+ n = lengths.size(0)
+ seq_range = torch.arange(0, max_len, device=lengths.device)
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+
+ return expanded_lengths >= lengths.unsqueeze(-1)
diff --git a/egs/zipvoice/zipvoice/zipformer.py b/egs/zipvoice/zipvoice/zipformer.py
new file mode 100644
index 000000000..190191cbb
--- /dev/null
+++ b/egs/zipvoice/zipvoice/zipformer.py
@@ -0,0 +1,1648 @@
+#!/usr/bin/env python3
+# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey,
+# Zengwei Yao,
+# Wei Kang
+# Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import math
+import random
+from typing import Optional, Tuple, Union
+
+import torch
+from scaling import (
+ Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+ ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+ ActivationDropoutAndLinear,
+ Balancer,
+ BiasNorm,
+ Dropout2,
+ FloatLike,
+ ScheduledFloat,
+ SwooshR,
+ Whiten,
+ limit_param_value,
+ penalize_abs_values_gt,
+ softmax,
+)
+from torch import Tensor, nn
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """Create sinusoidal timestep embeddings.
+
+ :param timesteps: shape of (N) or (N, T)
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: a Tensor of positional embeddings, of shape (N, dim) or (T, N, dim).
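+
+ Shape example (illustrative): for timesteps of shape (4,) and dim=192 the
+ result has shape (4, 192); for timesteps of shape (4, 100) the result has
+ shape (100, 4, 192), i.e. (T, N, dim).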
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
+ / half
+ )
+
+ if timesteps.dim() == 2:
+ timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
+
+ args = timesteps[..., None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
+ return embedding
+
+
+class TTSZipformer(nn.Module):
+ """
+ Args:
+
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
+ as downsampling_factor if they are single ints or one-element tuples. The length of
+ downsampling_factor defines the number of stacks.
+
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
+ Note: this is in addition to the downsampling factor of 2 that is applied in
+ the frontend (self.encoder_embed).
+ encoder_dim (int): embedding dimension shared by all encoder stacks.
+ num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
+ query_head_dim (int or Tuple[int]): dimension of the query and key per attention
+ head, per stack if a tuple.
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
+ attention head
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
+ Must be at least 4.
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
+ cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
+
+ pos_dim (int): the dimension of each positional-encoding vector prior to projection,
+ e.g. 128.
+
+ dropout (float): dropout rate
+ warmup_batches (float): number of batches to warm up over; this controls
+ dropout of encoder layers.
+ use_time_embed (bool): if True, take the time embedding as an additional input.
+ time_embed_dim: (int): the dimension of the time embedding.
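+
+ A minimal usage sketch (illustrative shapes and arguments, not the
+ recipe's actual configuration):
+
+ >>> model = TTSZipformer(in_dim=100, out_dim=100, downsampling_factor=(1, 2, 1))
+ >>> x = torch.randn(2, 200, 100) # (batch, frames, feat)
+ >>> t = torch.rand(2) # one timestep per utterance
+ >>> y = model(x, t=t) # y: (2, 200, 100)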
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ downsampling_factor: Tuple[int] = (2, 4),
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
+ encoder_dim: int = 384,
+ query_head_dim: int = 24,
+ pos_head_dim: int = 4,
+ value_head_dim: int = 12,
+ num_heads: int = 8,
+ feedforward_dim: int = 1536,
+ pos_dim: int = 192,
+ dropout: FloatLike = None, # see code below for default
+ warmup_batches: float = 4000.0,
+ use_time_embed: bool = True,
+ time_embed_dim: int = 192,
+ use_guidance_scale_embed: bool = False,
+ guidance_scale_embed_dim: int = 192,
+ use_conv: bool = True,
+ ) -> None:
+ super(TTSZipformer, self).__init__()
+
+ if dropout is None:
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+
+ def _to_tuple(x):
+ """Converts a single int or a 1-tuple of an int to a tuple with the same length
+ as downsampling_factor"""
+ if isinstance(x, int):
+ x = (x,)
+ if len(x) == 1:
+ x = x * len(downsampling_factor)
+ else:
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+ return x
+
+ def _assert_downsampling_factor(factors):
+ """assert downsampling_factor follows u-net style"""
+ assert factors[0] == 1 and factors[-1] == 1
+
+ for i in range(1, len(factors) // 2 + 1):
+ assert factors[i] == factors[i - 1] * 2
+
+ for i in range(len(factors) // 2 + 1, len(factors)):
+ assert factors[i] * 2 == factors[i - 1]
+
+ _assert_downsampling_factor(downsampling_factor)
+ self.downsampling_factor = downsampling_factor # tuple
+ num_encoder_layers = _to_tuple(num_encoder_layers)
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
+ self.encoder_dim = encoder_dim
+ self.num_encoder_layers = num_encoder_layers
+ self.query_head_dim = query_head_dim
+ self.value_head_dim = value_head_dim
+ self.num_heads = num_heads
+
+ self.use_time_embed = use_time_embed
+ self.use_guidance_scale_embed = use_guidance_scale_embed
+
+ self.time_embed_dim = time_embed_dim
+ if self.use_time_embed:
+ assert time_embed_dim != -1
+ else:
+ time_embed_dim = -1
+ self.guidance_scale_embed_dim = guidance_scale_embed_dim
+
+ self.in_proj = nn.Linear(in_dim, encoder_dim)
+ self.out_proj = nn.Linear(encoder_dim, out_dim)
+
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+ encoders = []
+
+ num_encoders = len(downsampling_factor)
+ for i in range(num_encoders):
+ encoder_layer = Zipformer2EncoderLayer(
+ embed_dim=encoder_dim,
+ pos_dim=pos_dim,
+ num_heads=num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ value_head_dim=value_head_dim,
+ feedforward_dim=feedforward_dim,
+ use_conv=use_conv,
+ cnn_module_kernel=cnn_module_kernel[i],
+ dropout=dropout,
+ )
+
+ # For the segment of the warmup period, we let the Conv2dSubsampling
+ # layer learn something. Then we start to warm up the other encoders.
+ encoder = Zipformer2Encoder(
+ encoder_layer,
+ num_encoder_layers[i],
+ embed_dim=encoder_dim,
+ time_embed_dim=time_embed_dim,
+ pos_dim=pos_dim,
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+ )
+
+ if downsampling_factor[i] != 1:
+ encoder = DownsampledZipformer2Encoder(
+ encoder,
+ dim=encoder_dim,
+ downsample=downsampling_factor[i],
+ )
+
+ encoders.append(encoder)
+
+ self.encoders = nn.ModuleList(encoders)
+ if self.use_time_embed:
+ self.time_embed = nn.Sequential(
+ nn.Linear(time_embed_dim, time_embed_dim * 2),
+ SwooshR(),
+ nn.Linear(time_embed_dim * 2, time_embed_dim),
+ )
+ else:
+ self.time_embed = None
+
+ if self.use_guidance_scale_embed:
+ self.guidance_scale_embed = ScaledLinear(
+ guidance_scale_embed_dim, time_embed_dim, bias=False, initial_scale=0.1
+ )
+ else:
+ self.guidance_scale_embed = None
+
+ def forward(
+ self,
+ x: Tensor,
+ t: Optional[Tensor] = None,
+ padding_mask: Optional[Tensor] = None,
+ guidance_scale: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+ t:
+ A t tensor of shape (batch_size,) or (batch_size, seq_len)
+ padding_mask:
+ The mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+ Returns:
+ Return the output tensor; its shape is (batch_size, seq_len, out_dim).
+ """
+ x = x.permute(1, 0, 2)
+ x = self.in_proj(x)
+
+ if t is not None:
+ assert t.dim() == 1 or t.dim() == 2, t.shape
+ time_emb = timestep_embedding(t, self.time_embed_dim)
+ if guidance_scale is not None:
+ assert (
+ guidance_scale.dim() == 1 or guidance_scale.dim() == 2
+ ), guidance_scale.shape
+ guidance_scale_emb = self.guidance_scale_embed(
+ timestep_embedding(guidance_scale, self.guidance_scale_embed_dim)
+ )
+ time_emb = time_emb + guidance_scale_emb
+ time_emb = self.time_embed(time_emb)
+ else:
+ time_emb = None
+
+ attn_mask = None
+
+ for i, module in enumerate(self.encoders):
+ x = module(
+ x,
+ time_emb=time_emb,
+ src_key_padding_mask=padding_mask,
+ attn_mask=attn_mask,
+ )
+ x = self.out_proj(x)
+ x = x.permute(1, 0, 2)
+ return x
+
+
+def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
+ return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
+
+
+class Zipformer2EncoderLayer(nn.Module):
+ """
+ Args:
+ embed_dim: the number of expected features in the input (required).
+ num_heads: the number of heads in the multi-head attention (required).
+ feedforward_dim: the dimension of the feedforward network model (required).
+ dropout: the dropout value (default=0.1).
+ cnn_module_kernel (int): Kernel size of convolution module (default=31).
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(
+ ... embed_dim=512, pos_dim=192, num_heads=8, query_head_dim=24,
+ ... pos_head_dim=4, value_head_dim=12, feedforward_dim=1536,
+ ... )
+ >>> src = torch.rand(10, 32, 512)
+ >>> pos_emb = torch.rand(1, 19, 192)
+ >>> out = encoder_layer(src, pos_emb)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ value_head_dim: int,
+ feedforward_dim: int,
+ dropout: FloatLike = 0.1,
+ cnn_module_kernel: int = 31,
+ use_conv: bool = True,
+ attention_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ conv_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+ ),
+ const_attention_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.25), (4000.0, 0.025), default=0
+ ),
+ ff2_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ ff3_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+ ),
+ bypass_skip_rate: FloatLike = ScheduledFloat(
+ (0.0, 0.5), (4000.0, 0.02), default=0
+ ),
+ ) -> None:
+ super(Zipformer2EncoderLayer, self).__init__()
+ self.embed_dim = embed_dim
+
+ # self.bypass implements layer skipping as well as bypass; see its default values.
+ self.bypass = BypassModule(
+ embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
+ )
+ # bypass_mid is bypass used in the middle of the layer.
+ self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
+
+ # skip probability for dynamic modules (meaning: anything but feedforward).
+ self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
+ # an additional skip probability that applies to ConvModule to stop it from
+ # contributing too much early on.
+ self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+
+ # ff2_skip_rate is to prevent the ff2 module from having output that's too big
+ # compared to its residual.
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
+ self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
+
+ self.const_attention_rate = copy.deepcopy(const_attention_rate)
+
+ self.self_attn_weights = RelPositionMultiheadAttentionWeights(
+ embed_dim,
+ pos_dim=pos_dim,
+ num_heads=num_heads,
+ query_head_dim=query_head_dim,
+ pos_head_dim=pos_head_dim,
+ dropout=0.0,
+ )
+
+ self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+ self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+ self.feed_forward1 = FeedforwardModule(
+ embed_dim, (feedforward_dim * 3) // 4, dropout
+ )
+
+ self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
+
+ self.feed_forward3 = FeedforwardModule(
+ embed_dim, (feedforward_dim * 5) // 4, dropout
+ )
+
+ self.nonlin_attention = NonlinAttention(
+ embed_dim, hidden_channels=3 * embed_dim // 4
+ )
+
+ self.use_conv = use_conv
+
+ if self.use_conv:
+ self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel)
+
+ self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel)
+
+ self.norm = BiasNorm(embed_dim)
+
+ self.balancer1 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.2,
+ max_abs=4.0,
+ )
+
+ # balancer for output of NonlinAttentionModule
+ self.balancer_na = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+ prob=0.05, # out of concern for memory usage
+ )
+
+ # balancer for output of feedforward2, prevent it from staying too
+ # small. give this a very small probability, even at the start of
+ # training, it's to fix a rare problem and it's OK to fix it slowly.
+ self.balancer_ff2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
+ max_abs=2.0,
+ prob=0.05,
+ )
+
+ self.balancer_ff3 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=0.7,
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
+ max_abs=4.0,
+ prob=0.05,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(4.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.balancer2 = Balancer(
+ embed_dim,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ min_abs=0.1,
+ max_abs=4.0,
+ )
+
+ def get_sequence_dropout_mask(
+ self, x: Tensor, dropout_rate: float
+ ) -> Optional[Tensor]:
+ if (
+ dropout_rate == 0.0
+ or not self.training
+ or torch.jit.is_scripting()
+ or torch.jit.is_tracing()
+ ):
+ return None
+ batch_size = x.shape[1]
+ mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+ return mask
+
+ def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+ """
+ Apply sequence-level dropout to x.
+ x shape: (seq_len, batch_size, embed_dim)
+ """
+ dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+ if dropout_mask is None:
+ return x
+ else:
+ return x * dropout_mask
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ time_emb: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the encoder layer.
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+ time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim)
+ or (seq_len, batch_size, embedding_dim) .
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns:
+ A tensor which has the same shape as src
+ """
+ src_orig = src
+
+ # dropout rate for non-feedforward submodules
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ attention_skip_rate = 0.0
+ else:
+ attention_skip_rate = (
+ float(self.attention_skip_rate) if self.training else 0.0
+ )
+
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ attn_weights = self.self_attn_weights(
+ src,
+ pos_emb=pos_emb,
+ attn_mask=attn_mask,
+ key_padding_mask=src_key_padding_mask,
+ )
+ if time_emb is not None:
+
+ src = src + time_emb
+
+ src = src + self.feed_forward1(src)
+
+ self_attn_dropout_mask = self.get_sequence_dropout_mask(
+ src, attention_skip_rate
+ )
+
+ selected_attn_weights = attn_weights[0:1]
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif self.training and random.random() < float(self.const_attention_rate):
+ # Make attention weights constant. The intention is to
+ # encourage these modules to do something similar to an
+ # averaging-over-time operation.
+ # only need the mask, can just use the 1st one and expand later
+ selected_attn_weights = selected_attn_weights[0:1]
+ selected_attn_weights = (selected_attn_weights > 0.0).to(
+ selected_attn_weights.dtype
+ )
+ selected_attn_weights = selected_attn_weights * (
+ 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
+ )
+
+ na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+
+ src = src + (
+ na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
+ )
+
+ self_attn = self.self_attn1(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if self.use_conv:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+
+ if time_emb is not None:
+ src = src + time_emb
+
+ src = src + self.sequence_dropout(
+ self.conv_module1(
+ src,
+ src_key_padding_mask=src_key_padding_mask,
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff2_skip_rate = 0.0
+ else:
+ ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
+ )
+
+ # bypass in the middle of the layer.
+ src = self.bypass_mid(src_orig, src)
+
+ self_attn = self.self_attn2(src, attn_weights)
+
+ src = src + (
+ self_attn
+ if self_attn_dropout_mask is None
+ else self_attn * self_attn_dropout_mask
+ )
+
+ if self.use_conv:
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ conv_skip_rate = 0.0
+ else:
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+
+ if time_emb is not None:
+ src = src + time_emb
+
+ src = src + self.sequence_dropout(
+ self.conv_module2(
+ src,
+ src_key_padding_mask=src_key_padding_mask,
+ ),
+ conv_skip_rate,
+ )
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ ff3_skip_rate = 0.0
+ else:
+ ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
+ src = src + self.sequence_dropout(
+ self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
+ )
+
+ src = self.balancer1(src)
+ src = self.norm(src)
+
+ src = self.bypass(src_orig, src)
+
+ src = self.balancer2(src)
+ src = self.whiten(src)
+
+ return src
+
+
+class Zipformer2Encoder(nn.Module):
+ r"""Zipformer2Encoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ pos_dim: the dimension for the relative positional encoding
+
+ Examples::
+ >>> encoder_layer = Zipformer2EncoderLayer(
+ ... embed_dim=512, pos_dim=192, num_heads=8, query_head_dim=24,
+ ... pos_head_dim=4, value_head_dim=12, feedforward_dim=1536,
+ ... )
+ >>> zipformer_encoder = Zipformer2Encoder(
+ ... encoder_layer, num_layers=6, embed_dim=512, time_embed_dim=-1,
+ ... pos_dim=192, warmup_begin=0.0, warmup_end=4000.0,
+ ... )
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = zipformer_encoder(src)
+ """
+
+ def __init__(
+ self,
+ encoder_layer: nn.Module,
+ num_layers: int,
+ embed_dim: int,
+ time_embed_dim: int,
+ pos_dim: int,
+ warmup_begin: float,
+ warmup_end: float,
+ initial_layerdrop_rate: float = 0.5,
+ final_layerdrop_rate: float = 0.05,
+ ) -> None:
+ super().__init__()
+ self.encoder_pos = CompactRelPositionalEncoding(
+ pos_dim, dropout_rate=0.15, length_factor=1.0
+ )
+ if time_embed_dim != -1:
+ self.time_emb = nn.Sequential(
+ SwooshR(),
+ nn.Linear(time_embed_dim, embed_dim),
+ )
+ else:
+ self.time_emb = None
+
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ assert 0 <= warmup_begin <= warmup_end
+
+ delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
+ cur_begin = warmup_begin # interpreted as a training batch index
+ for i in range(num_layers):
+ cur_end = cur_begin + delta
+ self.layers[i].bypass.skip_rate = ScheduledFloat(
+ (cur_begin, initial_layerdrop_rate),
+ (cur_end, final_layerdrop_rate),
+ default=0.0,
+ )
+ cur_begin = cur_end
+
+ def forward(
+ self,
+ src: Tensor,
+ time_emb: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim)
+ or (seq_len, batch_size, embedding_dim).
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ pos_emb = self.encoder_pos(src)
+ if self.time_emb is not None:
+ assert time_emb is not None
+ time_emb = self.time_emb(time_emb)
+ else:
+ assert time_emb is None
+
+ output = src
+
+ for i, mod in enumerate(self.layers):
+ output = mod(
+ output,
+ pos_emb,
+ time_emb=time_emb,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+
+ return output
+
+
+class BypassModule(nn.Module):
+ """
+ An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
+ layer-skipping. The bypass is limited during early stages of training to be close to
+ "straight-through", i.e. to not do the bypass operation much initially, in order to
+ force all the modules to learn something.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ skip_rate: FloatLike = 0.0,
+ straight_through_rate: FloatLike = 0.0,
+ scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
+ scale_max: FloatLike = 1.0,
+ ):
+ super().__init__()
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+ self.skip_rate = copy.deepcopy(skip_rate)
+ self.straight_through_rate = copy.deepcopy(straight_through_rate)
+ self.scale_min = copy.deepcopy(scale_min)
+ self.scale_max = copy.deepcopy(scale_max)
+
+ def _get_bypass_scale(self, batch_size: int):
+ # returns bypass-scale of shape (num_channels,),
+ # or (batch_size, num_channels,). This is actually the
+ # scale on the non-residual term, so 0 corresponds to bypassing
+ # this module.
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+ return self.bypass_scale
+ else:
+ ans = limit_param_value(
+ self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
+ )
+ skip_rate = float(self.skip_rate)
+ if skip_rate != 0.0:
+ mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
+ ans = ans * mask
+ # now ans is of shape (batch_size, num_channels), and is zero for sequences
+ # on which we have randomly chosen to do layer-skipping.
+ straight_through_rate = float(self.straight_through_rate)
+ if straight_through_rate != 0.0:
+ mask = (
+ torch.rand((batch_size, 1), device=ans.device)
+ < straight_through_rate
+ )
+ ans = torch.maximum(ans, mask.to(ans.dtype))
+ return ans
+
+ def forward(self, src_orig: Tensor, src: Tensor):
+ """
+ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
+ Returns: something with the same shape as src and src_orig
+ """
+ bypass_scale = self._get_bypass_scale(src.shape[1])
+ return src_orig + (src - src_orig) * bypass_scale
+
+
+class DownsampledZipformer2Encoder(nn.Module):
+ r"""
+ DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
+ after convolutional downsampling, and then upsampled again at the output, and combined
+ with the origin input, so that the output has the same shape as the input.
+ """
+
+ def __init__(self, encoder: nn.Module, dim: int, downsample: int):
+ super(DownsampledZipformer2Encoder, self).__init__()
+ self.downsample_factor = downsample
+ self.downsample = SimpleDownsample(downsample)
+ self.num_layers = encoder.num_layers
+ self.encoder = encoder
+ self.upsample = SimpleUpsample(downsample)
+ self.out_combiner = BypassModule(dim, straight_through_rate=0)
+
+ def forward(
+ self,
+ src: Tensor,
+ time_emb: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Downsample, go through encoder, upsample.
+
+ Args:
+ src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+ time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim)
+ or (seq_len, batch_size, embedding_dim) .
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+ interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+ True means masked position. May be None.
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+ masked position. May be None.
+
+ Returns: a Tensor with the same shape as src.
+ """
+ src_orig = src
+ src = self.downsample(src)
+ ds = self.downsample_factor
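+        # subsample the auxiliary inputs by the same factor so that they stay
+        # aligned with the downsampled frames.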
+ if time_emb is not None and time_emb.dim() == 3:
+ time_emb = time_emb[::ds]
+ if attn_mask is not None:
+ attn_mask = attn_mask[::ds, ::ds]
+ if src_key_padding_mask is not None:
+ src_key_padding_mask = src_key_padding_mask[..., ::ds]
+
+ src = self.encoder(
+ src,
+ time_emb=time_emb,
+ attn_mask=attn_mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ src = self.upsample(src)
+        # remove any extra frames that were introduced by padding to a multiple of downsample_factor
+ src = src[: src_orig.shape[0]]
+
+ return self.out_combiner(src_orig, src)
+
+
+class SimpleDownsample(torch.nn.Module):
+ """
+ Does downsampling with attention, by weighted sum.
+ """
+
+ def __init__(self, downsample: int):
+ super(SimpleDownsample, self).__init__()
+
+ self.bias = nn.Parameter(torch.zeros(downsample))
+
+ self.name = None # will be set from training code
+
+ self.downsample = downsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ x: (seq_len, batch_size, in_channels)
+ Returns a tensor of shape
+ ( (seq_len+downsample-1)//downsample, batch_size, channels)
+ """
+ (seq_len, batch_size, in_channels) = src.shape
+ ds = self.downsample
+ d_seq_len = (seq_len + ds - 1) // ds
+
+ # Pad to an exact multiple of self.downsample
+ # right-pad src, repeating the last element.
+ pad = d_seq_len * ds - seq_len
+ src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+ src = torch.cat((src, src_extra), dim=0)
+ assert src.shape[0] == d_seq_len * ds
+
+ src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+
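+        # the learned bias is softmaxed over the `ds` positions, so each output
+        # frame is a weighted average of its group of `ds` input frames.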
+ weights = self.bias.softmax(dim=0)
+ # weights: (downsample, 1, 1)
+ weights = weights.unsqueeze(-1).unsqueeze(-1)
+
+        # ans: (d_seq_len, batch_size, in_channels)
+ ans = (src * weights).sum(dim=1)
+
+ return ans
+
+
+class SimpleUpsample(torch.nn.Module):
+ """
+ A very simple form of upsampling that just repeats the input.
+ """
+
+ def __init__(self, upsample: int):
+ super(SimpleUpsample, self).__init__()
+ self.upsample = upsample
+
+ def forward(self, src: Tensor) -> Tensor:
+ """
+ x: (seq_len, batch_size, num_channels)
+ Returns a tensor of shape
+ ( (seq_len*upsample), batch_size, num_channels)
+ """
+ upsample = self.upsample
+ (seq_len, batch_size, num_channels) = src.shape
+ src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+ src = src.reshape(seq_len * upsample, batch_size, num_channels)
+ return src
+
+
+class CompactRelPositionalEncoding(torch.nn.Module):
+ """
+ Relative positional encoding module. This version is "compact" meaning it is able to encode
+ the important information about the relative position in a relatively small number of dimensions.
+ The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
+ make very little difference to the embedding. Such differences were potentially important
+ when encoding absolute position, but not important when encoding relative position because there
+ is now no need to compare two large offsets with each other.
+
+ Our embedding works by projecting the interval [-infinity,infinity] to a finite interval
+ using the atan() function, before doing the Fourier transform of that fixed interval. The
+    atan() function on its own would compress the "long tails" too much,
+    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+    function to compress large offsets to a smaller range before applying atan().
+    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
+    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim).
+
+
+ Args:
+ embed_dim: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length: just a heuristic for initialization.
+ length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+ less weight to small differences of offset near the origin.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ dropout_rate: FloatLike,
+ max_len: int = 1000,
+ length_factor: float = 1.0,
+ ) -> None:
+ """Construct a CompactRelPositionalEncoding object."""
+ super(CompactRelPositionalEncoding, self).__init__()
+ self.embed_dim = embed_dim
+ assert embed_dim % 2 == 0, embed_dim
+ self.dropout = Dropout2(dropout_rate)
+ self.pe = None
+ assert length_factor >= 1.0, length_factor
+ self.length_factor = length_factor
+ self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+ """Reset the positional encodings."""
+ T = x.size(0) + left_context_len
+
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(0) >= T * 2 - 1:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+
+        # if T == 4, x would contain [-3, -2, -1, 0, 1, 2, 3]
+ x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+
+ freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+        # `compression_length` is arbitrary/heuristic: if it is larger we have more resolution
+        # for small time offsets but less resolution for large time offsets.
+ compression_length = self.embed_dim**0.5
+        # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity,
+        # but it does so more slowly than x for large absolute values of x.
+        # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
+        # is important.
+ x_compressed = (
+ compression_length
+ * x.sign()
+ * ((x.abs() + compression_length).log() - math.log(compression_length))
+ )
+
+ # if self.length_factor == 1.0, then length_scale is chosen so that the
+ # FFT can exactly separate points close to the origin (T == 0). So this
+ # part of the formulation is not really heuristic.
+ # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+ length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
+
+ # note for machine implementations: if atan is not available, we can use:
+ # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
+ # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
+        x_atan = (x_compressed / length_scale).atan()  # results between -pi/2 and pi/2
+
+ cosines = (x_atan * freqs).cos()
+ sines = (x_atan * freqs).sin()
+
+ pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+ pe[:, 0::2] = cosines
+ pe[:, 1::2] = sines
+ pe[:, -1] = 1.0 # for bias.
+
+ self.pe = pe.to(dtype=x.dtype)
+
+ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+ """Create positional encoding.
+
+ Args:
+ x (Tensor): Input tensor (time, batch, `*`).
+ left_context_len: (int): Length of cached left context.
+
+ Returns:
+ positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
+ """
+ self.extend_pe(x, left_context_len)
+ x_size_left = x.size(0) + left_context_len
+ # length of positive side: x.size(0) + left_context_len
+ # length of negative side: x.size(0)
+ pos_emb = self.pe[
+ self.pe.size(0) // 2
+ - x_size_left
+ + 1 : self.pe.size(0) // 2 # noqa E203
+ + x.size(0),
+ :,
+ ]
+ pos_emb = pos_emb.unsqueeze(0)
+ return self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttentionWeights(nn.Module):
+ r"""Module that computes multi-head attention weights with relative position encoding.
+    Various other modules consume the resulting attention weights: see, for example, the
+    SelfAttention module which allows you to compute conventional attention.
+
+    This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a
+    Fixed-Length Context"; we still need to write up the differences.
+
+
+ Args:
+ embed_dim: number of channels at the input to this module, e.g. 256
+ pos_dim: dimension of the positional encoding vectors, e.g. 128.
+ num_heads: number of heads to compute weights for, e.g. 8
+ query_head_dim: dimension of the query (and key), per head. e.g. 24.
+ pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
+ dropout: dropout probability for attn_output_weights. Default: 0.0.
+ pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
+ any given call to forward(), in training time.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ pos_dim: int,
+ num_heads: int,
+ query_head_dim: int,
+ pos_head_dim: int,
+ dropout: float = 0.0,
+ pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
+ ) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.query_head_dim = query_head_dim
+ self.pos_head_dim = pos_head_dim
+ self.dropout = dropout
+ self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
+ self.name = None # will be overwritten in training code; for diagnostics.
+
+ key_head_dim = query_head_dim
+ in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
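+        # the output of in_proj is laid out along the channel dim as
+        # [query | key | positional-query]; forward() slices it in that order.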
+
+ # the initial_scale is supposed to take over the "scaling" factor of
+ # head_dim ** -0.5 that has been used in previous forms of attention,
+ # dividing it between the query and key. Note: this module is intended
+ # to be used with the ScaledAdam optimizer; with most other optimizers,
+ # it would be necessary to apply the scaling factor in the forward function.
+ self.in_proj = ScaledLinear(
+ embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+ )
+
+ self.whiten_keys = Whiten(
+ num_groups=num_heads,
+ whitening_limit=_whitening_schedule(3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.025,
+ )
+
+ # add a balancer for the keys that runs with very small probability, and
+ # tries to enforce that all dimensions have mean around zero. The
+ # weights produced by this module are invariant to adding a constant to
+ # the keys, so the derivative of the bias is mathematically zero; but
+ # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
+ # bias because the small numerical roundoff tends to have a non-random
+ # sign. This module is intended to prevent that. Use a very small
+ # probability; that should be sufficient to fix the problem.
+ self.balance_keys = Balancer(
+ key_head_dim * num_heads,
+ channel_dim=-1,
+ min_positive=0.4,
+ max_positive=0.6,
+ min_abs=0.0,
+ max_abs=100.0,
+ prob=0.025,
+ )
+
+ # linear transformation for positional encoding.
+ self.linear_pos = ScaledLinear(
+ pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
+ )
+
+ # the following are for diagnostics only, see --print-diagnostics option
+ self.copy_pos_query = Identity()
+ self.copy_query = Identity()
+
+ def forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""
+ Args:
+ x: input of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
+ are True in this mask will be ignored as sources in the attention weighting.
+ attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
+ interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
+ saying which positions are allowed to attend to which other positions.
+        Returns:
+           a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
+           interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+ """
+ x = self.in_proj(x)
+ query_head_dim = self.query_head_dim
+ pos_head_dim = self.pos_head_dim
+ num_heads = self.num_heads
+
+ seq_len, batch_size, _ = x.shape
+
+ query_dim = query_head_dim * num_heads
+
+ # self-attention
+ q = x[..., 0:query_dim]
+ k = x[..., query_dim : 2 * query_dim]
+ # p is the position-encoding query
+ p = x[..., 2 * query_dim :]
+ assert p.shape[-1] == num_heads * pos_head_dim, (
+ p.shape[-1],
+ num_heads,
+ pos_head_dim,
+ )
+
+ q = self.copy_query(q) # for diagnostics only, does nothing.
+ k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
+ p = self.copy_pos_query(p) # for diagnostics only, does nothing.
+
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+ k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
+
+ # time1 refers to target, time2 refers to source.
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+
+ attn_scores = torch.matmul(q, k)
+
+ use_pos_scores = False
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # random.random() can't be used under torch.jit scripting/tracing,
+            # so always include the positional scores in that case.
+ use_pos_scores = True
+ elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+ use_pos_scores = True
+
+ if use_pos_scores:
+ pos_emb = self.linear_pos(pos_emb)
+ seq_len2 = 2 * seq_len - 1
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+ 2, 0, 3, 1
+ )
+            # pos_emb shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2)
+
+            # (head, batch, time1, pos_head_dim) x (head, 1, pos_head_dim, seq_len2)
+            #   -> (head, batch, time1, seq_len2)
+            # [where seq_len2 represents relative position.]
+ pos_scores = torch.matmul(p, pos_emb)
+ # the following .as_strided() expression converts the last axis of pos_scores from relative
+ # to absolute position. I don't know whether I might have got the time-offsets backwards or
+ # not, but let this code define which way round it is supposed to be.
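+            # Concretely, the result satisfies
+            #   out[..., t, s] = pos_scores[..., t, (seq_len - 1) + s - t],
+            # i.e. relative index (seq_len - 1) corresponds to offset zero; the
+            # gather-based branch below (used for tracing) computes the same thing.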
+ if torch.jit.is_tracing():
+ (num_heads, batch_size, time1, n) = pos_scores.shape
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+ cols = torch.arange(seq_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+ pos_scores = pos_scores.reshape(-1, n)
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
+ else:
+ pos_scores = pos_scores.as_strided(
+ (num_heads, batch_size, seq_len, seq_len),
+ (
+ pos_scores.stride(0),
+ pos_scores.stride(1),
+ pos_scores.stride(2) - pos_scores.stride(3),
+ pos_scores.stride(3),
+ ),
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
+ )
+
+ attn_scores = attn_scores + pos_scores
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif self.training and random.random() < 0.1:
+            # This is a harder way of limiting the attention scores to not be
+            # too large.  It incurs a penalty if any of them has an absolute
+            # value greater than 25.0.  This should be outside the normal range
+            # of the attention scores.  We use this mechanism instead of, say,
+            # something added to the loss function involving the entropy,
+            # because once the entropy gets very small gradients through the
+            # softmax can become very small, and we'd get zero derivatives.  The
+            # choice of 1.0e-04 as the scale on the penalty makes this
+            # mechanism vulnerable to the absolute scale of the loss function,
+            # but we view this as a failsafe to avoid "implausible" parameter
+            # values rather than a regularization method that should be active
+            # under normal circumstances.
+ attn_scores = penalize_abs_values_gt(
+ attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
+ )
+
+ assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ if attn_mask is not None:
+ assert attn_mask.dtype == torch.bool
+            # use -1000 to avoid nan's where attn_mask and key_padding_mask mask out
+            # every position.  It's important that this be negative enough that
+            # exp(-1000) is exactly zero, for reasons related to const_attention_rate,
+            # which compares the final weights with zero.
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (
+ batch_size,
+ seq_len,
+ ), key_padding_mask.shape
+ attn_scores = attn_scores.masked_fill(
+ key_padding_mask.unsqueeze(1),
+ -1000,
+ )
+
+        # We use our own version of softmax, defined in scaling.py, which should
+        # save a little of the memory used in backprop: if we are in automatic
+        # mixed precision mode (amp / autocast), it only stores the
+        # half-precision output for backprop purposes.
+ attn_weights = softmax(attn_scores, dim=-1)
+
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ pass
+ elif random.random() < 0.001 and not self.training:
+ self._print_attn_entropy(attn_weights)
+
+ attn_weights = nn.functional.dropout(
+ attn_weights, p=self.dropout, training=self.training
+ )
+
+ return attn_weights
+
+ def _print_attn_entropy(self, attn_weights: Tensor):
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+ (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+
+ with torch.no_grad():
+ with torch.amp.autocast("cuda", enabled=False):
+ attn_weights = attn_weights.to(torch.float32)
+ attn_weights_entropy = (
+ -((attn_weights + 1.0e-20).log() * attn_weights)
+ .sum(dim=-1)
+ .mean(dim=(1, 2))
+ )
+ logging.debug(
+ f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
+ )
+
+
+class SelfAttention(nn.Module):
+ """
+ The simplest possible attention module. This one works with already-computed attention
+ weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
+
+ Args:
+ embed_dim: the input and output embedding dimension
+ num_heads: the number of attention heads
+ value_head_dim: the value dimension per head
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ value_head_dim: int,
+ ) -> None:
+ super().__init__()
+ self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
+
+ self.out_proj = ScaledLinear(
+ num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """
+ Args:
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
+ attn_weights.sum(dim=-1) == 1.
+ Returns:
+ a tensor with the same shape as x.
+ """
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
+ value_head_dim = x.shape[-1]
+
+ # todo: see whether there is benefit in overriding matmul
+ x = torch.matmul(attn_weights, x)
+        # x is now the attention-weighted values: (num_heads, batch_size, seq_len, value_head_dim)
+
+ x = (
+ x.permute(2, 1, 0, 3)
+ .contiguous()
+ .view(seq_len, batch_size, num_heads * value_head_dim)
+ )
+
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+ x = self.out_proj(x)
+ x = self.whiten(x)
+
+ return x
+
+
+class FeedforwardModule(nn.Module):
+ """Feedforward module in TTSZipformer model."""
+
+ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
+ super(FeedforwardModule, self).__init__()
+ self.in_proj = nn.Linear(embed_dim, feedforward_dim)
+
+ self.hidden_balancer = Balancer(
+ feedforward_dim,
+ channel_dim=-1,
+ min_positive=0.3,
+ max_positive=1.0,
+ min_abs=0.75,
+ max_abs=5.0,
+ )
+
+ # shared_dim=0 means we share the dropout mask along the time axis
+ self.out_proj = ActivationDropoutAndLinear(
+ feedforward_dim,
+ embed_dim,
+ activation="SwooshL",
+ dropout_p=dropout,
+ dropout_shared_dim=0,
+ bias=True,
+ initial_scale=0.1,
+ )
+
+ self.out_whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(self, x: Tensor):
+ x = self.in_proj(x)
+ x = self.hidden_balancer(x)
+ # out_proj contains SwooshL activation, then dropout, then linear.
+ x = self.out_proj(x)
+ x = self.out_whiten(x)
+ return x
+
+
+class NonlinAttention(nn.Module):
+ """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
+ from the attention module) in place of actual convolution. We also took out the second nonlinearity, the
+ one after the attention mechanism.
+
+ Args:
+        channels (int): The number of input/output channels.
+        hidden_channels (int): The dimension of the internal gated hidden streams.
+    """
+
+ def __init__(
+ self,
+ channels: int,
+ hidden_channels: int,
+ ) -> None:
+ super().__init__()
+
+ self.hidden_channels = hidden_channels
+
+ self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
+
+        # balancer that goes before the tanh.  Constrain the abs-value before the
+        # nonlinearity because we noticed that well-trained instances of this module have
+        # abs-values before the nonlinearity starting from about 3, and poorly-trained
+        # instances of the module have smaller abs values.
+ self.balancer = Balancer(
+ hidden_channels,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
+ max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
+ min_abs=0.5,
+ max_abs=5.0,
+ )
+ self.tanh = nn.Tanh()
+
+ self.identity1 = Identity() # for diagnostics.
+ self.identity2 = Identity() # for diagnostics.
+ self.identity3 = Identity() # for diagnostics.
+
+ self.out_proj = ScaledLinear(
+ hidden_channels, channels, bias=True, initial_scale=0.05
+ )
+
+ self.whiten1 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.whiten2 = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(5.0, ratio=3.0),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ attn_weights: Tensor,
+ ) -> Tensor:
+ """.
+ Args:
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+ Returns:
+ a Tensor with the same shape as x
+ """
+ x = self.in_proj(x)
+
+ (seq_len, batch_size, _) = x.shape
+ hidden_channels = self.hidden_channels
+
+ s, x, y = x.chunk(3, dim=2)
+
+        # s is the gating stream (it will go through tanh); x is the value stream that
+        # gets mixed by the attention weights; y is a second gate applied after the attention.
+
+ s = self.balancer(s)
+ s = self.tanh(s)
+
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+ x = self.whiten1(x)
+ x = x * s
+ x = self.identity1(x) # diagnostics only, it's the identity.
+
+ (seq_len, batch_size, embed_dim) = x.shape
+ num_heads = attn_weights.shape[0]
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = torch.matmul(attn_weights, x)
+ # now x: (num_heads, batch_size, seq_len, head_dim)
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+ y = self.identity2(y)
+ x = x * y
+ x = self.identity3(x)
+
+ x = self.out_proj(x)
+ x = self.whiten2(x)
+ return x
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Zipformer2 model.
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
+
+ Args:
+ channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ ) -> None:
+ """Construct a ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ bottleneck_dim = channels
+
+ self.in_proj = nn.Linear(
+ channels,
+ 2 * bottleneck_dim,
+ )
+        # the gradients on in_proj are a little noisy, likely to do with the
+        # sigmoid in the gating below.
+
+        # after in_proj we put x through a gated linear unit (chunk into two halves,
+        # multiply one half by the sigmoid of the other).
+        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+        # between 50 and 100 for different channels.  This will cause very peaky and
+        # sparse derivatives for the sigmoid gating function, which will tend to make
+        # the loss function not learn effectively.  (For most layers the average absolute values
+        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+        # at the output of in_proj is around 0.35 to 0.45 for different
+        # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.)  The idea is that if we
+        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+        # it will be in a better position to start learning something, i.e. to latch onto
+        # the correct range.
+ self.balancer1 = Balancer(
+ bottleneck_dim,
+ channel_dim=-1,
+ min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
+ max_positive=1.0,
+ min_abs=1.5,
+ max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
+ )
+
+ self.activation1 = Identity() # for diagnostics
+
+ self.sigmoid = nn.Sigmoid()
+
+ self.activation2 = Identity() # for diagnostics
+
+ assert kernel_size % 2 == 1
+
+ self.depthwise_conv = nn.Conv1d(
+ in_channels=bottleneck_dim,
+ out_channels=bottleneck_dim,
+ groups=bottleneck_dim,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ )
+
+ self.balancer2 = Balancer(
+ bottleneck_dim,
+ channel_dim=1,
+ min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
+ max_positive=1.0,
+ min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
+ max_abs=10.0,
+ )
+
+ self.whiten = Whiten(
+ num_groups=1,
+ whitening_limit=_whitening_schedule(7.5),
+ prob=(0.025, 0.25),
+ grad_scale=0.01,
+ )
+
+ self.out_proj = ActivationDropoutAndLinear(
+ bottleneck_dim,
+ channels,
+ activation="SwooshR",
+ dropout_p=0.0,
+ initial_scale=0.05,
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ src_key_padding_mask: the mask for the src keys per batch (optional):
+ (batch, #time), contains True in masked positions.
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+
+ x = self.in_proj(x) # (time, batch, 2*channels)
+
+ x, s = x.chunk(2, dim=2)
+ s = self.balancer1(s)
+ s = self.sigmoid(s)
+ x = self.activation1(x) # identity.
+ x = x * s
+ x = self.activation2(x) # identity
+
+ # (time, batch, channels)
+
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
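+        # zero out padded frames so that the depthwise convolution does not mix
+        # real frames with padding.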
+ if src_key_padding_mask is not None:
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+ x = self.depthwise_conv(x)
+
+ x = self.balancer2(x)
+ x = x.permute(2, 0, 1) # (time, batch, channels)
+
+ x = self.whiten(x) # (time, batch, channels)
+ x = self.out_proj(x) # (time, batch, channels)
+
+ return x
diff --git a/egs/zipvoice/zipvoice/zipvoice_infer.py b/egs/zipvoice/zipvoice/zipvoice_infer.py
new file mode 100644
index 000000000..472ad700d
--- /dev/null
+++ b/egs/zipvoice/zipvoice/zipvoice_infer.py
@@ -0,0 +1,642 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script generates speech with our pre-trained ZipVoice or
+ ZipVoice-Distill models. Required models will be automatically
+ downloaded from HuggingFace.
+
+Usage:
+
+Note: If you have trouble connecting to HuggingFace,
+ try switching the endpoint to a mirror site:
+
+export HF_ENDPOINT=https://hf-mirror.com
+
+(1) Inference of a single sentence:
+
+python3 zipvoice/zipvoice_infer.py \
+ --model-name "zipvoice_distill" \
+ --prompt-wav prompt.wav \
+ --prompt-text "I am a prompt." \
+ --text "I am a sentence." \
+ --res-wav-path result.wav
+
+(2) Inference of a list of sentences:
+python3 zipvoice/zipvoice_infer.py \
+    --model-name "zipvoice_distill" \
+ --test-list test.tsv \
+ --res-dir results
+
+`--model-name` can be `zipvoice` or `zipvoice_distill`,
+ which are the models before and after distillation, respectively.
+
+Each line of `test.tsv` is in the format of
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
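+
+For example, a line of `test.tsv` could look like this (with hypothetical names/paths):
+ `utt_0001\tThis is the prompt transcription.\tprompts/utt_0001.wav\tThis is the text to synthesize.`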
+"""
+
+import argparse
+import datetime as dt
+import os
+
+import numpy as np
+import safetensors.torch
+import soundfile as sf
+import torch
+import torch.nn as nn
+import torchaudio
+from feature import TorchAudioFbank, TorchAudioFbankConfig
+from huggingface_hub import hf_hub_download
+from lhotse.utils import fix_random_seed
+from model import get_distill_model, get_model
+from tokenizer import TokenizerEmilia
+from utils import AttributeDict
+from vocos import Vocos
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="zipvoice_distill",
+ choices=["zipvoice", "zipvoice_distill"],
+ help="The model used for inference",
+ )
+
+ parser.add_argument(
+ "--test-list",
+ type=str,
+ default=None,
+ help="The list of prompt speech, prompt_transcription, "
+ "and text to synthesizein the format of "
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
+ )
+
+ parser.add_argument(
+ "--prompt-wav",
+ type=str,
+ default=None,
+ help="The prompt wav to mimic",
+ )
+
+ parser.add_argument(
+ "--prompt-text",
+ type=str,
+ default=None,
+ help="The transcription of the prompt wav",
+ )
+
+ parser.add_argument(
+ "--text",
+ type=str,
+ default=None,
+ help="The text to synthesize",
+ )
+
+ parser.add_argument(
+ "--res-dir",
+ type=str,
+ default="results",
+ help="Path name of the generated wavs dir, "
+ "used when decdode-list is not None",
+ )
+
+ parser.add_argument(
+ "--res-wav-path",
+ type=str,
+ default="result.wav",
+ help="Path name of the generated wav path, " "used when decdode-list is None",
+ )
+
+ parser.add_argument(
+ "--guidance-scale",
+ type=float,
+ default=None,
+ help="The scale of classifier-free guidance during inference.",
+ )
+
+ parser.add_argument(
+ "--num-step",
+ type=int,
+ default=None,
+ help="The number of sampling steps.",
+ )
+
+ parser.add_argument(
+ "--feat-scale",
+ type=float,
+ default=0.1,
+ help="The scale factor of fbank feature",
+ )
+
+ parser.add_argument(
+ "--speed",
+ type=float,
+ default=1.0,
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
+ )
+
+ parser.add_argument(
+ "--t-shift",
+ type=float,
+ default=0.5,
+ help="Shift t to smaller ones if t_shift < 1.0",
+ )
+
+ parser.add_argument(
+ "--target-rms",
+ type=float,
+ default=0.1,
+ help="Target speech normalization rms value",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=666,
+ help="Random seed",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--fm-decoder-downsampling-factor",
+ type=str,
+ default="1,2,4,2,1",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-num-layers",
+ type=str,
+ default="2,2,4,4,4",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-cnn-module-kernel",
+ type=str,
+ default="31,15,7,15,31",
+ help="Sizes of convolutional kernels in convolution modules "
+ "in each encoder stack: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-feedforward-dim",
+ type=int,
+ default=1536,
+ help="Feedforward dimension of the zipformer encoder layers, "
+ "per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-num-heads",
+ type=int,
+ default=4,
+ help="Number of attention heads in the zipformer encoder layers: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--fm-decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in encoder stacks: a single int "
+ "or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-downsampling-factor",
+ type=str,
+ default="1",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-num-layers",
+ type=str,
+ default="4",
+ help="Number of zipformer encoder layers per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-feedforward-dim",
+ type=int,
+ default=512,
+ help="Feedforward dimension of the zipformer encoder layers, "
+ "per stack, comma separated.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-cnn-module-kernel",
+ type=str,
+ default="9",
+ help="Sizes of convolutional kernels in convolution modules in "
+ "each encoder stack: a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-num-heads",
+ type=int,
+ default=4,
+ help="Number of attention heads in the zipformer encoder layers: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--text-encoder-dim",
+ type=int,
+ default=192,
+ help="Embedding dimension in encoder stacks: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--query-head-dim",
+ type=int,
+ default=32,
+ help="Query/key dimension per head in encoder stacks: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--value-head-dim",
+ type=int,
+ default=12,
+ help="Value dimension per head in encoder stacks: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-head-dim",
+ type=int,
+ default=4,
+ help="Positional-encoding dimension per head in encoder stacks: "
+ "a single int or comma-separated list.",
+ )
+
+ parser.add_argument(
+ "--pos-dim",
+ type=int,
+ default=48,
+ help="Positional-encoding embedding dimension",
+ )
+
+ parser.add_argument(
+ "--time-embed-dim",
+ type=int,
+ default=192,
+ help="Embedding dimension of timestamps embedding.",
+ )
+
+ parser.add_argument(
+ "--text-embed-dim",
+ type=int,
+ default=192,
+ help="Embedding dimension of text embedding.",
+ )
+
+
+def get_params() -> AttributeDict:
+ params = AttributeDict(
+ {
+ "sampling_rate": 24000,
+ "frame_shift_ms": 256 / 24000 * 1000,
+ "feat_dim": 100,
+ }
+ )
+
+ return params
+
+
+def get_vocoder():
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+ return vocoder
+
+
+def generate_sentence(
+ save_path: str,
+ prompt_text: str,
+ prompt_wav: str,
+ text: str,
+ model: nn.Module,
+ vocoder: nn.Module,
+ tokenizer: TokenizerEmilia,
+ feature_extractor: TorchAudioFbank,
+ device: torch.device,
+ num_step: int = 16,
+ guidance_scale: float = 1.0,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+):
+ """
+ Generate waveform of a text based on a given prompt
+ waveform and its transcription.
+
+ Args:
+ save_path (str): Path to save the generated wav.
+ prompt_text (str): Transcription of the prompt wav.
+ prompt_wav (str): Path to the prompt wav file.
+ text (str): Text to be synthesized into a waveform.
+ model (nn.Module): The model used for generation.
+ vocoder (nn.Module): The vocoder used to convert features to waveforms.
+ tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
+ feature_extractor (TorchAudioFbank): The feature extractor used to
+ extract acoustic features.
+ device (torch.device): The device on which computations are performed.
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
+ guidance_scale (float, optional): Scale for classifier-free guidance.
+ Defaults to 1.0.
+ speed (float, optional): Speed control. Defaults to 1.0.
+ t_shift (float, optional): Time shift. Defaults to 0.5.
+ target_rms (float, optional): Target RMS for waveform normalization.
+ Defaults to 0.1.
+ feat_scale (float, optional): Scale for features.
+ Defaults to 0.1.
+ sampling_rate (int, optional): Sampling rate for the waveform.
+ Defaults to 24000.
+ Returns:
+ metrics (dict): Dictionary containing time and real-time
+ factor metrics for processing.
+ """
+ # Convert text to tokens
+ tokens = tokenizer.texts_to_token_ids([text])
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
+
+ # Load and preprocess prompt wav
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
+ if prompt_rms < target_rms:
+ prompt_wav = prompt_wav * target_rms / prompt_rms
+
+ if prompt_sampling_rate != sampling_rate:
+ resampler = torchaudio.transforms.Resample(
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
+ )
+ prompt_wav = resampler(prompt_wav)
+
+ # Extract features from prompt wav
+ prompt_features = feature_extractor.extract(
+ prompt_wav, sampling_rate=sampling_rate
+ ).to(device)
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
+
+ # Start timing
+ start_t = dt.datetime.now()
+
+ # Generate features
+ (
+ pred_features,
+ pred_features_lens,
+ pred_prompt_features,
+ pred_prompt_features_lens,
+ ) = model.sample(
+ tokens=tokens,
+ prompt_tokens=prompt_tokens,
+ prompt_features=prompt_features,
+ prompt_features_lens=prompt_features_lens,
+ speed=speed,
+ t_shift=t_shift,
+ duration="predict",
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ )
+
+ # Postprocess predicted features
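+    # dividing by feat_scale undoes the scaling applied to the fbank features,
+    # mapping the prediction back to the original fbank scale before vocoding.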
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
+
+ # Start vocoder processing
+ start_vocoder_t = dt.datetime.now()
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
+
+ # Calculate processing times and real-time factors
+ t = (dt.datetime.now() - start_t).total_seconds()
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
+ wav_seconds = wav.shape[-1] / sampling_rate
+ rtf = t / wav_seconds
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
+ rtf_vocoder = t_vocoder / wav_seconds
+ metrics = {
+ "t": t,
+ "t_no_vocoder": t_no_vocoder,
+ "t_vocoder": t_vocoder,
+ "wav_seconds": wav_seconds,
+ "rtf": rtf,
+ "rtf_no_vocoder": rtf_no_vocoder,
+ "rtf_vocoder": rtf_vocoder,
+ }
+
+    # If the prompt was boosted to reach target_rms, scale the output back down
+    # so the generated wav matches the prompt's original volume.
+ if prompt_rms < target_rms:
+ wav = wav * prompt_rms / target_rms
+ wav = wav[0].cpu().numpy()
+ sf.write(save_path, wav, sampling_rate)
+
+ return metrics
+
+
+def generate(
+ res_dir: str,
+ test_list: str,
+ model: nn.Module,
+ vocoder: nn.Module,
+ tokenizer: TokenizerEmilia,
+ feature_extractor: TorchAudioFbank,
+ device: torch.device,
+ num_step: int = 16,
+ guidance_scale: float = 1.0,
+ speed: float = 1.0,
+ t_shift: float = 0.5,
+ target_rms: float = 0.1,
+ feat_scale: float = 0.1,
+ sampling_rate: int = 24000,
+):
+ total_t = []
+ total_t_no_vocoder = []
+ total_t_vocoder = []
+ total_wav_seconds = []
+
+ with open(test_list, "r") as fr:
+ lines = fr.readlines()
+
+ for i, line in enumerate(lines):
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
+ save_path = f"{res_dir}/{wav_name}.wav"
+ metrics = generate_sentence(
+ save_path=save_path,
+ prompt_text=prompt_text,
+ prompt_wav=prompt_wav,
+ text=text,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=device,
+ num_step=num_step,
+ guidance_scale=guidance_scale,
+ speed=speed,
+ t_shift=t_shift,
+ target_rms=target_rms,
+ feat_scale=feat_scale,
+ sampling_rate=sampling_rate,
+ )
+ print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
+ total_t.append(metrics["t"])
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
+ total_t_vocoder.append(metrics["t_vocoder"])
+ total_wav_seconds.append(metrics["wav_seconds"])
+
+ print(f"Average RTF: {np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
+ print(
+ f"Average RTF w/o vocoder: "
+ f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
+ )
+ print(
+ f"Average RTF vocoder: "
+ f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
+ )
+
+
+@torch.inference_mode()
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ params = get_params()
+ params.update(vars(args))
+
+ model_defaults = {
+ "zipvoice": {
+ "num_step": 16,
+ "guidance_scale": 1.0,
+ },
+ "zipvoice_distill": {
+ "num_step": 8,
+ "guidance_scale": 3.0,
+ },
+ }
+
+ model_specific_defaults = model_defaults.get(params.model_name, {})
+
+ for param, value in model_specific_defaults.items():
+ if getattr(params, param) == parser.get_default(param):
+ setattr(params, param, value)
+ print(f"Setting {param} to default value: {value}")
+
+ assert (params.test_list is not None) ^ (
+ (params.prompt_wav and params.prompt_text and params.text) is not None
+ ), (
+ "For inference, please provide prompts and text with either '--test-list'"
+ " or '--prompt-wav, --prompt-text and --text'."
+ )
+
+ if torch.cuda.is_available():
+ params.device = torch.device("cuda", 0)
+ else:
+ params.device = torch.device("cpu")
+
+ token_file = hf_hub_download("zhu-han/ZipVoice", filename="tokens_emilia.txt")
+
+ tokenizer = TokenizerEmilia(token_file)
+
+ params.vocab_size = tokenizer.vocab_size
+ params.pad_id = tokenizer.pad_id
+ fix_random_seed(params.seed)
+
+ if params.model_name == "zipvoice_distill":
+ model = get_distill_model(params)
+ model_ckpt = hf_hub_download(
+ "zhu-han/ZipVoice", filename="exp_zipvoice_distill/model.safetensors"
+ )
+ else:
+ model = get_model(params)
+ model_ckpt = hf_hub_download(
+ "zhu-han/ZipVoice", filename="exp_zipvoice/model.safetensors"
+ )
+
+ safetensors.torch.load_model(model, model_ckpt)
+
+ model = model.to(params.device)
+ model.eval()
+
+ vocoder = get_vocoder()
+ vocoder = vocoder.to(params.device)
+ vocoder.eval()
+
+ config = TorchAudioFbankConfig(
+ sampling_rate=params.sampling_rate,
+ n_mels=100,
+ n_fft=1024,
+ hop_length=256,
+ )
+ feature_extractor = TorchAudioFbank(config)
+
+ if params.test_list:
+ os.makedirs(params.res_dir, exist_ok=True)
+ generate(
+ res_dir=params.res_dir,
+ test_list=params.test_list,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=params.device,
+ num_step=params.num_step,
+ guidance_scale=params.guidance_scale,
+ speed=params.speed,
+ t_shift=params.t_shift,
+ target_rms=params.target_rms,
+ feat_scale=params.feat_scale,
+ sampling_rate=params.sampling_rate,
+ )
+ else:
+ generate_sentence(
+ save_path=params.res_wav_path,
+ prompt_text=params.prompt_text,
+ prompt_wav=params.prompt_wav,
+ text=params.text,
+ model=model,
+ vocoder=vocoder,
+ tokenizer=tokenizer,
+ feature_extractor=feature_extractor,
+ device=params.device,
+ num_step=params.num_step,
+ guidance_scale=params.guidance_scale,
+ speed=params.speed,
+ t_shift=params.t_shift,
+ target_rms=params.target_rms,
+ feat_scale=params.feat_scale,
+ sampling_rate=params.sampling_rate,
+ )
+ print("Done")
+
+
+if __name__ == "__main__":
+ main()