diff --git a/egs/zipvoice/README.md b/egs/zipvoice/README.md new file mode 100644 index 000000000..0eed8f540 --- /dev/null +++ b/egs/zipvoice/README.md @@ -0,0 +1,360 @@ +## ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching + + +[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](http://arxiv.org/abs/2506.13053) +[![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://zipvoice.github.io/) + + +## Overview +ZipVoice is a high-quality zero-shot TTS model with a small model size and fast inference speed. +#### Key features: + +- Small and fast: only 123M parameters. + +- High-quality: state-of-the-art voice cloning performance in speaker similarity, intelligibility, and naturalness. + +- Multi-lingual: support Chinese and English. + + +## News +**2025/06/16**: 🔥 ZipVoice is released. + + +## Installation +``` +pip install -r requirements.txt +``` + +## Usage + +To generate speech with our pre-trained ZipVoice or ZipVoice-Distill models, use the following commands (Required models will be downloaded from HuggingFace): + +### 1. Inference of a single sentence: +```bash +python3 zipvoice/zipvoice_infer.py \ + --model-name "zipvoice_distill" \ + --prompt-wav prompt.wav \ + --prompt-text "I am the transcription of the prompt wav." \ + --text "I am the text to be synthesized." \ + --res-wav-path result.wav +``` + +### 2. Inference of a list of sentences: +```bash +python3 zipvoice/zipvoice_infer.py \ + --model-name "zipvoice_distill" \ + --test-list test.tsv \ + --res-dir results/test +``` +- `--model-name` can be `zipvoice` or `zipvoice_distill`, which are models before and after distillation, respectively. +- Each line of `test.tsv` is in the format of `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`. + + +> **Note:** If you having trouble connecting to HuggingFace, try: +```bash +export HF_ENDPOINT=https://hf-mirror.com +``` + +## Training Your Own Model + +The following steps show how to train a model from scratch on Emilia and LibriTTS datasets, respectively. + +### 1. Data Preparation + +#### 1.1. Prepare the Emilia dataset + +#### 1.2 Prepare the LibriTTS dataset + +See [local/prepare_libritts.sh](local/prepare_libritts.sh) + +### 2. Training + +#### 2.1 Traininig on Emilia + +
+Expand to view training steps + +##### 2.1.1 Train the ZipVoice model + +- Training: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/train_flow.py \ + --world-size 8 \ + --use-fp16 1 \ + --dataset emilia \ + --max-duration 500 \ + --lr-hours 30000 \ + --lr-batches 7500 \ + --token-file "data/tokens_emilia.txt" \ + --manifest-dir "data/fbank_emilia" \ + --num-epochs 11 \ + --exp-dir zipvoice/exp_zipvoice +``` + +- Average the checkpoints to produce the final model: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/generate_averaged_model.py \ + --epoch 11 \ + --avg 4 \ + --distill 0 \ + --token-file data/tokens_emilia.txt \ + --dataset "emilia" \ + --exp-dir ./zipvoice/exp_zipvoice +# The generated model is zipvoice/exp_zipvoice/epoch-11-avg-4.pt +``` + +##### 2.1.2. Train the ZipVoice-Distill model (Optional) + +- The first-stage distillation: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/train_distill.py \ + --world-size 8 \ + --use-fp16 1 \ + --tensorboard 1 \ + --dataset "emilia" \ + --base-lr 0.0005 \ + --max-duration 500 \ + --token-file "data/tokens_emilia.txt" \ + --manifest-dir "data/fbank_emilia" \ + --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \ + --num-updates 60000 \ + --distill-stage "first" \ + --exp-dir zipvoice/exp_zipvoice_distill_1stage +``` + +- Average checkpoints for the second-stage initialization: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/generate_averaged_model.py \ + --iter 60000 \ + --avg 7 \ + --distill 1 \ + --token-file data/tokens_emilia.txt \ + --dataset "emilia" \ + --exp-dir ./zipvoice/exp_zipvoice_distill_1stage +# The generated model is zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt +``` + +- The second-stage distillation: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/train_distill.py \ + --world-size 8 \ + --use-fp16 1 \ + --tensorboard 1 \ + --dataset "emilia" \ + --base-lr 0.0001 \ + --max-duration 200 \ + --token-file "data/tokens_emilia.txt" \ + --manifest-dir "data/fbank_emilia" \ + --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \ + --num-updates 2000 \ + --distill-stage "second" \ + --exp-dir zipvoice/exp_zipvoice_distill_new +``` +
+ + +#### 2.2 Traininig on LibriTTS + +
+Expand to view training steps + +##### 2.2.1 Train the ZipVoice model + +- Training: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/train_flow.py \ + --world-size 8 \ + --use-fp16 1 \ + --dataset libritts \ + --max-duration 250 \ + --lr-epochs 10 \ + --lr-batches 7500 \ + --token-file "data/tokens_libritts.txt" \ + --manifest-dir "data/fbank_libritts" \ + --num-epochs 60 \ + --exp-dir zipvoice/exp_zipvoice_libritts +``` + +- Average the checkpoints to produce the final model: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/generate_averaged_model.py \ + --epoch 60 \ + --avg 10 \ + --distill 0 \ + --token-file data/tokens_libritts.txt \ + --dataset "libritts" \ + --exp-dir ./zipvoice/exp_zipvoice_libritts +# The generated model is zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt +``` + +##### 2.1.2 Train the ZipVoice-Distill model (Optional) + +- The first-stage distillation: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/train_distill.py \ + --world-size 8 \ + --use-fp16 1 \ + --tensorboard 1 \ + --dataset "libritts" \ + --base-lr 0.001 \ + --max-duration 250 \ + --token-file "data/tokens_libritts.txt" \ + --manifest-dir "data/fbank_libritts" \ + --teacher-model zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \ + --num-epochs 6 \ + --distill-stage "first" \ + --exp-dir zipvoice/exp_zipvoice_distill_1stage_libritts +``` + +- Average checkpoints for the second-stage initialization: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 ./zipvoice/generate_averaged_model.py \ + --epoch 6 \ + --avg 3 \ + --distill 1 \ + --token-file data/tokens_libritts.txt \ + --dataset "libritts" \ + --exp-dir ./zipvoice/exp_zipvoice_distill_1stage_libritts +# The generated model is zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt +``` + +- The second-stage distillation: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/train_distill.py \ + --world-size 8 \ + --use-fp16 1 \ + --tensorboard 1 \ + --dataset "libritts" \ + --base-lr 0.001 \ + --max-duration 250 \ + --token-file "data/tokens_libritts.txt" \ + --manifest-dir "data/fbank_libritts" \ + --teacher-model zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \ + --num-epochs 6 \ + --distill-stage "second" \ + --exp-dir zipvoice/exp_zipvoice_distill_libritts +``` + +- Average checkpoints to produce the final model: + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 ./zipvoice/generate_averaged_model.py \ + --epoch 6 \ + --avg 3 \ + --distill 1 \ + --token-file data/tokens_libritts.txt \ + --dataset "libritts" \ + --exp-dir ./zipvoice/exp_zipvoice_distill_libritts +# The generated model is ./zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt +``` +
+ + +### 3. Inference with the trained model + +#### 3.1 Inference with the model trained on Emilia +
+Expand to view inference commands. + +##### 3.1.1 ZipVoice model before distill: +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/infer.py \ + --checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \ + --distill 0 \ + --token-file "data/tokens_emilia.txt" \ + --test-list test.tsv \ + --res-dir results/test \ + --num-step 16 \ + --guidance-scale 1 +``` + +##### 3.1.2 ZipVoice-Distill model before distill: +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/infer.py \ + --checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \ + --distill 1 \ + --token-file "data/tokens_emilia.txt" \ + --test-list test.tsv \ + --res-dir results/test_distill \ + --num-step 8 \ + --guidance-scale 3 +``` +
+ + +#### 3.2 Inference with the model trained on LibriTTS + +
+Expand to view inference commands. + +##### 3.2.1 ZipVoice model before distill: +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/infer.py \ + --checkpoint zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \ + --distill 0 \ + --token-file "data/tokens_libritts.txt" \ + --test-list test.tsv \ + --res-dir results/test_libritts \ + --num-step 8 \ + --guidance-scale 1 \ + --target-rms 1.0 \ + --t-shift 0.7 +``` + +##### 3.2.2 ZipVoice-Distill model before distill + +```bash +export PYTHONPATH=../../:$PYTHONPATH +python3 zipvoice/infer.py \ + --checkpoint zipvoice/exp_zipvoice_distill/epoch-6-avg-3.pt \ + --distill 1 \ + --token-file "data/tokens_libritts.txt" \ + --test-list test.tsv \ + --res-dir results/test_distill_libritts \ + --num-step 4 \ + --guidance-scale 3 \ + --target-rms 1.0 \ + --t-shift 0.7 +``` +
+ +### 4. Evaluation on benchmarks + +See [local/evaluate.sh](local/evaluate.sh) for details of objective metrics evaluation +on three test sets, i.e., LibriSpeech-PC test-clean, Seed-TTS test-en and Seed-TTS test-zh. + + +## Citation + +```bibtex +@article{zhu-2025-zipvoice, + title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching}, + author={Han Zhu and Wei Kang and Zengwei Yao and Liyong Guo and Fangjun Kuang and Zhaoqing Li and Weiji Zhuang and Long Lin and Daniel Povey} + journal={arXiv preprint arXiv:2506.13053}, + year={2025}, +} +``` \ No newline at end of file diff --git a/egs/zipvoice/local/compute_fbank_libritts.py b/egs/zipvoice/local/compute_fbank_libritts.py new file mode 100755 index 000000000..0c9f464ea --- /dev/null +++ b/egs/zipvoice/local/compute_fbank_libritts.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao,) +# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LibriTTS dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from feature import TorchAudioFbank, TorchAudioFbankConfig +from lhotse import CutSet, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="""Sampling rate of the waveform for computing fbank, + the default value for LibriTTS is 24000, waveform files will be + resampled if a different sample rate is provided""", + ) + + return parser.parse_args() + + +def compute_fbank_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000): + src_dir = Path("data/manifests_libritts") + output_dir = Path("data/fbank_libritts") + num_jobs = min(32, os.cpu_count()) + + prefix = "libritts" + suffix = "jsonl.gz" + if dataset is None: + dataset_parts = ( + "dev-clean", + "test-clean", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + config = TorchAudioFbankConfig( + sampling_rate=sampling_rate, + n_mels=100, + n_fft=1024, + hop_length=256, + ) + extractor = TorchAudioFbank(config) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if sampling_rate != 24000: + logging.info(f"Resampling waveforms to {sampling_rate}") + cut_set = cut_set.resample(sampling_rate) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_libritts() diff --git a/egs/zipvoice/local/evaluate.sh b/egs/zipvoice/local/evaluate.sh new file mode 100644 index 000000000..fbc0eed9a --- /dev/null +++ b/egs/zipvoice/local/evaluate.sh @@ -0,0 +1,102 @@ +export CUDA_VISIBLE_DEVICES="0" +export PYTHONWARNINGS=ignore +export PYTHONPATH=../../:$PYTHONPATH + +# Uncomment this if you have trouble connecting to HuggingFace +# export HF_ENDPOINT=https://hf-mirror.com + +start_stage=1 +end_stage=3 + +# Models used for SIM-o evaluation. +# SV model wavlm_large_finetune.pth is downloaded from https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification +# SSL model wavlm_large.pt is downloaded from https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt +sv_model_path=model/UniSpeech/wavlm_large_finetune.pth +wavlm_model_path=model/s3prl/wavlm_large.pt + +# Models used for UTMOS evaluation. +# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -P model/huggingface/utmos/utmos.pt +# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -P model/huggingface/utmos/wav2vec_small.pt +utmos_model_path=model/huggingface/utmos/utmos.pt +wav2vec_model_path=model/huggingface/utmos/wav2vec_small.pt + + +if [ $start_stage -le 1 ] && [ $end_stage -ge 1 ]; then + + echo "=====Evaluate for Seed-TTS test-en=======" + test_list=testset/test_seedtts_en.tsv + wav_path=results/zipvoice_seedtts_en + + echo $wav_path + echo "-----Computing SIM-o-----" + python3 local/evaluate_sim.py \ + --sv-model-path ${sv_model_path} \ + --ssl-model-path ${wavlm_model_path} \ + --eval-path ${wav_path} \ + --test-list ${test_list} + + echo "-----Computing WER-----" + python3 local/evaluate_wer_seedtts.py \ + --test-list ${test_list} \ + --wav-path ${wav_path} \ + --lang "en" + + echo "-----Computing UTSMOS-----" + python3 local/evaluate_utmos.py \ + --wav-path ${wav_path} \ + --utmos-model-path ${utmos_model_path} \ + --ssl-model-path ${wav2vec_model_path} + +fi + +if [ $start_stage -le 2 ] && [ $end_stage -ge 2 ]; then + echo "=====Evaluate for Seed-TTS test-zh=======" + test_list=testset/test_seedtts_zh.tsv + wav_path=results/zipvoice_seedtts_zh + + echo $wav_path + echo "-----Computing SIM-o-----" + python3 local/evaluate_sim.py \ + --sv-model-path ${sv_model_path} \ + --ssl-model-path ${wavlm_model_path} \ + --eval-path ${wav_path} \ + --test-list ${test_list} + + echo "-----Computing WER-----" + python3 local/evaluate_wer_seedtts.py \ + --test-list ${test_list} \ + --wav-path ${wav_path} \ + --lang "zh" + + echo "-----Computing UTSMOS-----" + python3 local/evaluate_utmos.py \ + --wav-path ${wav_path} \ + --utmos-model-path ${utmos_model_path} \ + --ssl-model-path ${wav2vec_model_path} +fi + +if [ $start_stage -le 3 ] && [ $end_stage -ge 3 ]; then + echo "=====Evaluate for Librispeech test-clean=======" + test_list=testset/test_librispeech_pc_test_clean.tsv + wav_path=results/zipvoice_librispeech_test_clean + + echo $wav_path + echo "-----Computing SIM-o-----" + python3 local/evaluate_sim.py \ + --sv-model-path ${sv_model_path} \ + --ssl-model-path ${wavlm_model_path} \ + --eval-path ${wav_path} \ + --test-list ${test_list} + + echo "-----Computing WER-----" + python3 local/evaluate_wer_hubert.py \ + --test-list ${test_list} \ + --wav-path ${wav_path} \ + + echo "-----Computing UTSMOS-----" + python3 local/evaluate_utmos.py \ + --wav-path ${wav_path} \ + --utmos-model-path ${utmos_model_path} \ + --ssl-model-path ${wav2vec_model_path} + +fi \ No newline at end of file diff --git a/egs/zipvoice/local/evaluate_sim.py b/egs/zipvoice/local/evaluate_sim.py new file mode 100644 index 000000000..df439cf2c --- /dev/null +++ b/egs/zipvoice/local/evaluate_sim.py @@ -0,0 +1,508 @@ +""" +Calculate pairwise Speaker Similarity betweeen two speech directories. +SV model wavlm_large_finetune.pth is downloaded from + https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification +SSL model wavlm_large.pt is downloaded from + https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt +""" +import argparse +import logging +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +logging.basicConfig(level=logging.INFO) + + +def get_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--eval-path", type=str, help="path of the evaluated speech directory" + ) + parser.add_argument( + "--test-list", + type=str, + help="path of the file list that contains the corresponding " + "relationship between the prompt and evaluated speech. " + "The first column is the wav name and the third column is the prompt speech", + ) + parser.add_argument( + "--sv-model-path", + type=str, + default="model/UniSpeech/wavlm_large_finetune.pth", + help="path of the wavlm-based ECAPA-TDNN model", + ) + parser.add_argument( + "--ssl-model-path", + type=str, + default="model/s3prl/wavlm_large.pt", + help="path of the wavlm SSL model", + ) + return parser + + +class SpeakerSimilarity: + def __init__( + self, + sv_model_path="model/UniSpeech/wavlm_large_finetune.pth", + ssl_model_path="model/s3prl/wavlm_large.pt", + ): + """ + Initialize + """ + self.sample_rate = 16000 + self.channels = 1 + self.device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + logging.info("[Speaker Similarity] Using device: {}".format(self.device)) + self.model = ECAPA_TDNN_WAVLLM( + feat_dim=1024, + channels=512, + emb_dim=256, + sr=16000, + ssl_model_path=ssl_model_path, + ) + state_dict = torch.load( + sv_model_path, map_location=lambda storage, loc: storage + ) + self.model.load_state_dict(state_dict["model"], strict=False) + self.model.to(self.device) + self.model.eval() + + def get_embeddings(self, wav_list, dtype="float32"): + """ + Get embeddings + """ + + def _load_speech_task(fname, sample_rate): + + wav_data, sr = sf.read(fname, dtype=dtype) + if sr != sample_rate: + wav_data = librosa.resample( + wav_data, orig_sr=sr, target_sr=self.sample_rate + ) + wav_data = torch.from_numpy(wav_data) + + return wav_data + + embd_lst = [] + for file_path in tqdm(wav_list): + speech = _load_speech_task(file_path, self.sample_rate) + speech = speech.to(self.device) + with torch.no_grad(): + embd = self.model([speech]) + embd_lst.append(embd) + + return embd_lst + + def score( + self, + eval_path, + test_list, + dtype="float32", + ): + """ + Computes the Speaker Similarity (SIM-o) between two directories of speech files. + + Parameters: + - eval_path (str): Path to the directory containing evaluation speech files. + - test_list (str): Path to the file containing the corresponding relationship + between prompt and evaluated speech. + - dtype (str, optional): Data type for loading speech. Default is "float32". + + Returns: + - float: The Speaker Similarity (SIM-o) score between the two directories + of speech files. + """ + prompt_wavs = [] + eval_wavs = [] + with open(test_list, "r") as fr: + lines = fr.readlines() + for line in lines: + wav_name, prompt_text, prompt_wav, text = line.strip().split("\t") + prompt_wavs.append(prompt_wav) + eval_wavs.append(os.path.join(eval_path, wav_name + ".wav")) + embds_prompt = self.get_embeddings(prompt_wavs, dtype=dtype) + + embds_eval = self.get_embeddings(eval_wavs, dtype=dtype) + + # Check if embeddings are empty + if len(embds_prompt) == 0: + logging.info("[Speaker Similarity] real set dir is empty, exiting...") + return -1 + if len(embds_eval) == 0: + logging.info("[Speaker Similarity] eval set dir is empty, exiting...") + return -1 + + scores = [] + for real_embd, eval_embd in zip(embds_prompt, embds_eval): + scores.append( + torch.nn.functional.cosine_similarity(real_embd, eval_embd, dim=-1) + .detach() + .cpu() + .numpy() + ) + + return np.mean(scores) + + +# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN + +""" Res2Conv1d + BatchNorm1d + ReLU +""" + + +class Res2Conv1dReluBn(nn.Module): + """ + in_channels == out_channels == channels + """ + + def __init__( + self, + channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + scale=4, + ): + super().__init__() + assert channels % scale == 0, "{} % {} != 0".format(channels, scale) + self.scale = scale + self.width = channels // scale + self.nums = scale if scale == 1 else scale - 1 + + self.convs = [] + self.bns = [] + for i in range(self.nums): + self.convs.append( + nn.Conv1d( + self.width, + self.width, + kernel_size, + stride, + padding, + dilation, + bias=bias, + ) + ) + self.bns.append(nn.BatchNorm1d(self.width)) + self.convs = nn.ModuleList(self.convs) + self.bns = nn.ModuleList(self.bns) + + def forward(self, x): + out = [] + spx = torch.split(x, self.width, 1) + for i in range(self.nums): + if i == 0: + sp = spx[i] + else: + sp = sp + spx[i] + # Order: conv -> relu -> bn + sp = self.convs[i](sp) + sp = self.bns[i](F.relu(sp)) + out.append(sp) + if self.scale != 1: + out.append(spx[self.nums]) + out = torch.cat(out, dim=1) + + return out + + +""" Conv1d + BatchNorm1d + ReLU +""" + + +class Conv1dReluBn(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias + ) + self.bn = nn.BatchNorm1d(out_channels) + + def forward(self, x): + return self.bn(F.relu(self.conv(x))) + + +""" The SE connection of 1D case. +""" + + +class SE_Connect(nn.Module): + def __init__(self, channels, se_bottleneck_dim=128): + super().__init__() + self.linear1 = nn.Linear(channels, se_bottleneck_dim) + self.linear2 = nn.Linear(se_bottleneck_dim, channels) + + def forward(self, x): + out = x.mean(dim=2) + out = F.relu(self.linear1(out)) + out = torch.sigmoid(self.linear2(out)) + out = x * out.unsqueeze(2) + + return out + + +""" SE-Res2Block of the ECAPA-TDNN architecture. +""" + + +# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale): +# return nn.Sequential( +# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0), +# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale), +# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0), +# SE_Connect(channels) +# ) + + +class SE_Res2Block(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + scale, + se_bottleneck_dim, + ): + super().__init__() + self.Conv1dReluBn1 = Conv1dReluBn( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.Res2Conv1dReluBn = Res2Conv1dReluBn( + out_channels, kernel_size, stride, padding, dilation, scale=scale + ) + self.Conv1dReluBn2 = Conv1dReluBn( + out_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) + + self.shortcut = None + if in_channels != out_channels: + self.shortcut = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + ) + + def forward(self, x): + residual = x + if self.shortcut: + residual = self.shortcut(x) + + x = self.Conv1dReluBn1(x) + x = self.Res2Conv1dReluBn(x) + x = self.Conv1dReluBn2(x) + x = self.SE_Connect(x) + + return x + residual + + +""" Attentive weighted mean and standard deviation pooling. +""" + + +class AttentiveStatsPool(nn.Module): + def __init__(self, in_dim, attention_channels=128, global_context_att=False): + super().__init__() + self.global_context_att = global_context_att + + # Use Conv1d with stride == 1 rather than Linear, + # then we don't need to transpose inputs. + if global_context_att: + self.linear1 = nn.Conv1d( + in_dim * 3, attention_channels, kernel_size=1 + ) # equals W and b in the paper + else: + self.linear1 = nn.Conv1d( + in_dim, attention_channels, kernel_size=1 + ) # equals W and b in the paper + self.linear2 = nn.Conv1d( + attention_channels, in_dim, kernel_size=1 + ) # equals V and k in the paper + + def forward(self, x): + + if self.global_context_att: + context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-10 + ).expand_as(x) + x_in = torch.cat((x, context_mean, context_std), dim=1) + else: + x_in = x + + # DON'T use ReLU here! In experiments, I find ReLU hard to converge. + alpha = torch.tanh(self.linear1(x_in)) + # alpha = F.relu(self.linear1(x_in)) + alpha = torch.softmax(self.linear2(alpha), dim=2) + mean = torch.sum(alpha * x, dim=2) + residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 + std = torch.sqrt(residuals.clamp(min=1e-9)) + return torch.cat([mean, std], dim=1) + + +class ECAPA_TDNN_WAVLLM(nn.Module): + def __init__( + self, + feat_dim=80, + channels=512, + emb_dim=192, + global_context_att=False, + sr=16000, + ssl_model_path=None, + ): + super().__init__() + self.sr = sr + + if ssl_model_path is None: + self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large") + else: + self.feature_extract = torch.hub.load( + os.path.dirname(ssl_model_path), + "wavlm_local", + source="local", + ckpt=ssl_model_path, + ) + + if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( + self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention" + ): + self.feature_extract.model.encoder.layers[ + 23 + ].self_attn.fp32_attention = False + if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( + self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention" + ): + self.feature_extract.model.encoder.layers[ + 11 + ].self_attn.fp32_attention = False + + self.feat_num = self.get_feat_num() + self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) + + self.instance_norm = nn.InstanceNorm1d(feat_dim) + # self.channels = [channels] * 4 + [channels * 3] + self.channels = [channels] * 4 + [1536] + + self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2) + self.layer2 = SE_Res2Block( + self.channels[0], + self.channels[1], + kernel_size=3, + stride=1, + padding=2, + dilation=2, + scale=8, + se_bottleneck_dim=128, + ) + self.layer3 = SE_Res2Block( + self.channels[1], + self.channels[2], + kernel_size=3, + stride=1, + padding=3, + dilation=3, + scale=8, + se_bottleneck_dim=128, + ) + self.layer4 = SE_Res2Block( + self.channels[2], + self.channels[3], + kernel_size=3, + stride=1, + padding=4, + dilation=4, + scale=8, + se_bottleneck_dim=128, + ) + + # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1) + cat_channels = channels * 3 + self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) + self.pooling = AttentiveStatsPool( + self.channels[-1], + attention_channels=128, + global_context_att=global_context_att, + ) + self.bn = nn.BatchNorm1d(self.channels[-1] * 2) + self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) + + def get_feat_num(self): + self.feature_extract.eval() + wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)] + with torch.no_grad(): + features = self.feature_extract(wav) + select_feature = features["hidden_states"] + if isinstance(select_feature, (list, tuple)): + return len(select_feature) + else: + return 1 + + def get_feat(self, x): + with torch.no_grad(): + x = self.feature_extract([sample for sample in x]) + + x = x["hidden_states"] + if isinstance(x, (list, tuple)): + x = torch.stack(x, dim=0) + else: + x = x.unsqueeze(0) + norm_weights = ( + F.softmax(self.feature_weight, dim=-1) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1) + ) + x = (norm_weights * x).sum(dim=0) + x = torch.transpose(x, 1, 2) + 1e-6 + + x = self.instance_norm(x) + return x + + def forward(self, x): + x = self.get_feat(x) + + out1 = self.layer1(x) + out2 = self.layer2(out1) + out3 = self.layer3(out2) + out4 = self.layer4(out3) + + out = torch.cat([out2, out3, out4], dim=1) + out = F.relu(self.conv(out)) + out = self.bn(self.pooling(out)) + out = self.linear(out) + + return out + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + SIM = SpeakerSimilarity( + sv_model_path=args.sv_model_path, ssl_model_path=args.ssl_model_path + ) + score = SIM.score(args.eval_path, args.test_list) + logging.info(f"SIM-o score: {score:.3f}") diff --git a/egs/zipvoice/local/evaluate_utmos.py b/egs/zipvoice/local/evaluate_utmos.py new file mode 100644 index 000000000..369e139c1 --- /dev/null +++ b/egs/zipvoice/local/evaluate_utmos.py @@ -0,0 +1,294 @@ +""" +Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system +adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo + +# Download model checkpoints +wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -P model/huggingface/utmos/utmos.pt +wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -P model/huggingface/utmos/wav2vec_small.pt +""" + +import argparse +import logging +import os + +import fairseq +import librosa +import numpy as np +import pytorch_lightning as pl +import soundfile as sf +import torch +import torch.nn as nn +from tqdm import tqdm + +logging.basicConfig(level=logging.INFO) + + +def get_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--wav-path", type=str, help="path of the evaluated speech directory" + ) + parser.add_argument( + "--utmos-model-path", + type=str, + default="model/huggingface/utmos/utmos.pt", + help="path of the UTMOS model", + ) + parser.add_argument( + "--ssl-model-path", + type=str, + default="model/huggingface/utmos/wav2vec_small.pt", + help="path of the wav2vec SSL model", + ) + return parser + + +class UTMOSScore: + """Predicting score for each audio clip.""" + + def __init__(self, utmos_model_path, ssl_model_path): + self.sample_rate = 16000 + self.device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + self.model = ( + BaselineLightningModule.load_from_checkpoint( + utmos_model_path, ssl_model_path=ssl_model_path + ) + .eval() + .to(self.device) + ) + + def score(self, wavs: torch.Tensor) -> torch.Tensor: + """ + Args: + wavs: waveforms to be evaluated. When len(wavs) == 1 or 2, + the model processes the input as a single audio clip. The model + performs batch processing when len(wavs) == 3. + """ + if len(wavs.shape) == 1: + out_wavs = wavs.unsqueeze(0).unsqueeze(0) + elif len(wavs.shape) == 2: + out_wavs = wavs.unsqueeze(0) + elif len(wavs.shape) == 3: + out_wavs = wavs + else: + raise ValueError("Dimension of input tensor needs to be <= 3.") + bs = out_wavs.shape[0] + batch = { + "wav": out_wavs, + "domains": torch.zeros(bs, dtype=torch.int).to(self.device), + "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288, + } + with torch.no_grad(): + output = self.model(batch) + + return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3 + + def score_dir(self, dir, dtype="float32"): + def _load_speech_task(fname, sample_rate): + + wav_data, sr = sf.read(fname, dtype=dtype) + if sr != sample_rate: + wav_data = librosa.resample( + wav_data, orig_sr=sr, target_sr=self.sample_rate + ) + wav_data = torch.from_numpy(wav_data) + + return wav_data + + score_lst = [] + for fname in tqdm(os.listdir(dir)): + speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate) + speech = speech.to(self.device) + with torch.no_grad(): + score = self.score(speech) + score_lst.append(score.item()) + return np.mean(score_lst) + + +def load_ssl_model(ckpt_path="wav2vec_small.pt"): + SSL_OUT_DIM = 768 + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [ckpt_path] + ) + ssl_model = model[0] + ssl_model.remove_pretraining_modules() + return SSL_model(ssl_model, SSL_OUT_DIM) + + +class BaselineLightningModule(pl.LightningModule): + def __init__(self, ssl_model_path): + super().__init__() + self.construct_model(ssl_model_path) + self.save_hyperparameters() + + def construct_model(self, ssl_model_path): + self.feature_extractors = nn.ModuleList( + [ + load_ssl_model(ckpt_path=ssl_model_path), + DomainEmbedding(3, 128), + ] + ) + output_dim = sum( + [ + feature_extractor.get_output_dim() + for feature_extractor in self.feature_extractors + ] + ) + output_layers = [ + LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim) + ] + output_dim = output_layers[-1].get_output_dim() + output_layers.append( + Projection( + hidden_dim=2048, + activation=torch.nn.ReLU(), + range_clipping=False, + input_dim=output_dim, + ) + ) + + self.output_layers = nn.ModuleList(output_layers) + + def forward(self, inputs): + outputs = {} + for feature_extractor in self.feature_extractors: + outputs.update(feature_extractor(inputs)) + x = outputs + for output_layer in self.output_layers: + x = output_layer(x, inputs) + return x + + +class SSL_model(nn.Module): + def __init__(self, ssl_model, ssl_out_dim) -> None: + super(SSL_model, self).__init__() + self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim + + def forward(self, batch): + wav = batch["wav"] + wav = wav.squeeze(1) # [batches, wav_len] + res = self.ssl_model(wav, mask=False, features_only=True) + x = res["x"] + return {"ssl-feature": x} + + def get_output_dim(self): + return self.ssl_out_dim + + +class DomainEmbedding(nn.Module): + def __init__(self, n_domains, domain_dim) -> None: + super().__init__() + self.embedding = nn.Embedding(n_domains, domain_dim) + self.output_dim = domain_dim + + def forward(self, batch): + return {"domain-feature": self.embedding(batch["domains"])} + + def get_output_dim(self): + return self.output_dim + + +class LDConditioner(nn.Module): + """ + Conditions ssl output by listener embedding + """ + + def __init__(self, input_dim, judge_dim, num_judges=None): + super().__init__() + self.input_dim = input_dim + self.judge_dim = judge_dim + self.num_judges = num_judges + assert num_judges != None + self.judge_embedding = nn.Embedding(num_judges, self.judge_dim) + # concat [self.output_layer, phoneme features] + + self.decoder_rnn = nn.LSTM( + input_size=self.input_dim + self.judge_dim, + hidden_size=512, + num_layers=1, + batch_first=True, + bidirectional=True, + ) # linear? + self.out_dim = self.decoder_rnn.hidden_size * 2 + + def get_output_dim(self): + return self.out_dim + + def forward(self, x, batch): + judge_ids = batch["judge_id"] + if "phoneme-feature" in x.keys(): + concatenated_feature = torch.cat( + ( + x["ssl-feature"], + x["phoneme-feature"] + .unsqueeze(1) + .expand(-1, x["ssl-feature"].size(1), -1), + ), + dim=2, + ) + else: + concatenated_feature = x["ssl-feature"] + if "domain-feature" in x.keys(): + concatenated_feature = torch.cat( + ( + concatenated_feature, + x["domain-feature"] + .unsqueeze(1) + .expand(-1, concatenated_feature.size(1), -1), + ), + dim=2, + ) + if judge_ids != None: + concatenated_feature = torch.cat( + ( + concatenated_feature, + self.judge_embedding(judge_ids) + .unsqueeze(1) + .expand(-1, concatenated_feature.size(1), -1), + ), + dim=2, + ) + decoder_output, (h, c) = self.decoder_rnn(concatenated_feature) + return decoder_output + + +class Projection(nn.Module): + def __init__(self, input_dim, hidden_dim, activation, range_clipping=False): + super(Projection, self).__init__() + self.range_clipping = range_clipping + output_dim = 1 + if range_clipping: + self.proj = nn.Tanh() + + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + activation, + nn.Dropout(0.3), + nn.Linear(hidden_dim, output_dim), + ) + self.output_dim = output_dim + + def forward(self, x, batch): + output = self.net(x) + + # range clipping + if self.range_clipping: + return self.proj(output) * 2.0 + 3 + else: + return output + + def get_output_dim(self): + return self.output_dim + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + UTMOS = UTMOSScore( + utmos_model_path=args.utmos_model_path, ssl_model_path=args.ssl_model_path + ) + score = UTMOS.score_dir(args.wav_path) + logging.info(f"UTMOS score: {score:.2f}") diff --git a/egs/zipvoice/local/evaluate_wer_hubert.py b/egs/zipvoice/local/evaluate_wer_hubert.py new file mode 100644 index 000000000..d30346e67 --- /dev/null +++ b/egs/zipvoice/local/evaluate_wer_hubert.py @@ -0,0 +1,172 @@ +""" +Calculate WER with Hubert models. +""" +import argparse +import os +import re +from pathlib import Path + +import librosa +import numpy as np +import soundfile as sf +import torch +from jiwer import compute_measures +from tqdm import tqdm +from transformers import pipeline + + +def get_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument("--wav-path", type=str, help="path of the speech directory") + parser.add_argument( + "--decode-path", + type=str, + default=None, + help="path of the output file of WER information", + ) + parser.add_argument( + "--model-path", + type=str, + default=None, + help="path of the local hubert model, e.g., model/huggingface/hubert-large-ls960-ft", + ) + parser.add_argument( + "--test-list", + type=str, + default="test.tsv", + help="path of the transcript tsv file, where the first column " + "is the wav name and the last column is the transcript", + ) + parser.add_argument( + "--batch-size", type=int, default=16, help="decoding batch size" + ) + return parser + + +def post_process(text: str): + text = text.replace("‘", "'") + text = text.replace("’", "'") + text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower()) + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def process_one(hypo, truth): + truth = post_process(truth) + hypo = post_process(hypo) + + measures = compute_measures(truth, hypo) + word_num = len(truth.split(" ")) + wer = measures["wer"] + subs = measures["substitutions"] + dele = measures["deletions"] + inse = measures["insertions"] + return (truth, hypo, wer, subs, dele, inse, word_num) + + +class SpeechEvalDataset(torch.utils.data.Dataset): + def __init__(self, wav_path: str, test_list: str): + super().__init__() + self.wav_name = [] + self.wav_paths = [] + self.transcripts = [] + with Path(test_list).open("r", encoding="utf8") as f: + meta = [item.split("\t") for item in f.read().rstrip().split("\n")] + for item in meta: + self.wav_name.append(item[0]) + self.wav_paths.append(Path(wav_path, item[0] + ".wav")) + self.transcripts.append(item[-1]) + + def __len__(self): + return len(self.wav_paths) + + def __getitem__(self, index: int): + wav, sampling_rate = sf.read(self.wav_paths[index]) + item = { + "array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000), + "sampling_rate": 16000, + "reference": self.transcripts[index], + "wav_name": self.wav_name[index], + } + return item + + +def main(test_list, wav_path, model_path, decode_path, batch_size, device): + + if model_path is not None: + pipe = pipeline( + "automatic-speech-recognition", + model=model_path, + device=device, + tokenizer=model_path, + ) + else: + pipe = pipeline( + "automatic-speech-recognition", + model="facebook/hubert-large-ls960-ft", + device=device, + ) + + dataset = SpeechEvalDataset(wav_path, test_list) + + bar = tqdm( + pipe( + dataset, + generate_kwargs={"language": "english", "task": "transcribe"}, + batch_size=batch_size, + ), + total=len(dataset), + ) + + wers = [] + inses = [] + deles = [] + subses = [] + word_nums = 0 + if decode_path: + decode_dir = os.path.dirname(decode_path) + if not os.path.exists(decode_dir): + os.makedirs(decode_dir) + fout = open(decode_path, "w") + for out in bar: + wav_name = out["wav_name"][0] + transcription = post_process(out["text"].strip()) + text_ref = post_process(out["reference"][0].strip()) + truth, hypo, wer, subs, dele, inse, word_num = process_one( + transcription, text_ref + ) + if decode_path: + fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n") + wers.append(float(wer)) + inses.append(float(inse)) + deles.append(float(dele)) + subses.append(float(subs)) + word_nums += word_num + + wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3) + subs = round(np.mean(subses) * 100, 3) + dele = round(np.mean(deles) * 100, 3) + inse = round(np.mean(inses) * 100, 3) + print(f"WER: {wer}%\n") + if decode_path: + fout.write(f"WER: {wer}%\n") + fout.flush() + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + else: + device = torch.device("cpu") + main( + args.test_list, + args.wav_path, + args.model_path, + args.decode_path, + args.batch_size, + device, + ) diff --git a/egs/zipvoice/local/evaluate_wer_seedtts.py b/egs/zipvoice/local/evaluate_wer_seedtts.py new file mode 100644 index 000000000..f7e256387 --- /dev/null +++ b/egs/zipvoice/local/evaluate_wer_seedtts.py @@ -0,0 +1,181 @@ +""" +Calculate WER with Whisper-large-v3 or Paraformer models, +following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval +""" + +import argparse +import os +import string + +import numpy as np +import scipy +import soundfile as sf +import torch +import zhconv +from funasr import AutoModel +from jiwer import compute_measures +from tqdm import tqdm +from transformers import WhisperForConditionalGeneration, WhisperProcessor +from zhon.hanzi import punctuation + + +def get_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument("--wav-path", type=str, help="path of the speech directory") + parser.add_argument( + "--decode-path", + type=str, + default=None, + help="path of the output file of WER information", + ) + parser.add_argument( + "--model-path", + type=str, + default=None, + help="path of the local whisper and paraformer model, " + "e.g., whisper: model/huggingface/whisper-large-v3/, " + "paraformer: model/huggingface/paraformer-zh/", + ) + parser.add_argument( + "--test-list", + type=str, + default="test.tsv", + help="path of the transcript tsv file, where the first column " + "is the wav name and the last column is the transcript", + ) + parser.add_argument("--lang", type=str, help="decoded language, zh or en") + return parser + + +def load_en_model(model_path): + if model_path is None: + model_path = "openai/whisper-large-v3" + processor = WhisperProcessor.from_pretrained(model_path) + model = WhisperForConditionalGeneration.from_pretrained(model_path) + return processor, model + + +def load_zh_model(model_path): + if model_path is None: + model_path = "paraformer-zh" + model = AutoModel(model=model_path) + return model + + +def process_one(hypo, truth, lang): + punctuation_all = punctuation + string.punctuation + for x in punctuation_all: + if x == "'": + continue + truth = truth.replace(x, "") + hypo = hypo.replace(x, "") + + truth = truth.replace(" ", " ") + hypo = hypo.replace(" ", " ") + + if lang == "zh": + truth = " ".join([x for x in truth]) + hypo = " ".join([x for x in hypo]) + elif lang == "en": + truth = truth.lower() + hypo = hypo.lower() + else: + raise NotImplementedError + + measures = compute_measures(truth, hypo) + word_num = len(truth.split(" ")) + wer = measures["wer"] + subs = measures["substitutions"] + dele = measures["deletions"] + inse = measures["insertions"] + return (truth, hypo, wer, subs, dele, inse, word_num) + + +def main(test_list, wav_path, model_path, decode_path, lang, device): + if lang == "en": + processor, model = load_en_model(model_path) + model.to(device) + elif lang == "zh": + model = load_zh_model(model_path) + params = [] + for line in open(test_list).readlines(): + line = line.strip() + items = line.split("\t") + wav_name, text_ref = items[0], items[-1] + file_path = os.path.join(wav_path, wav_name + ".wav") + assert os.path.exists(file_path), f"{file_path}" + + params.append((file_path, text_ref)) + wers = [] + inses = [] + deles = [] + subses = [] + word_nums = 0 + if decode_path: + decode_dir = os.path.dirname(decode_path) + if not os.path.exists(decode_dir): + os.makedirs(decode_dir) + fout = open(decode_path, "w") + for wav_path, text_ref in tqdm(params): + if lang == "en": + wav, sr = sf.read(wav_path) + if sr != 16000: + wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr)) + input_features = processor( + wav, sampling_rate=16000, return_tensors="pt" + ).input_features + input_features = input_features.to(device) + forced_decoder_ids = processor.get_decoder_prompt_ids( + language="english", task="transcribe" + ) + predicted_ids = model.generate( + input_features, forced_decoder_ids=forced_decoder_ids + ) + transcription = processor.batch_decode( + predicted_ids, skip_special_tokens=True + )[0] + elif lang == "zh": + res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True) + transcription = res[0]["text"] + transcription = zhconv.convert(transcription, "zh-cn") + + truth, hypo, wer, subs, dele, inse, word_num = process_one( + transcription, text_ref, lang + ) + if decode_path: + fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n") + wers.append(float(wer)) + inses.append(float(inse)) + deles.append(float(dele)) + subses.append(float(subs)) + word_nums += word_num + + wer_avg = round(np.mean(wers) * 100, 3) + wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3) + subs = round(np.mean(subses) * 100, 3) + dele = round(np.mean(deles) * 100, 3) + inse = round(np.mean(inses) * 100, 3) + print(f"Seed-TTS WER: {wer_avg}%\n") + print(f"WER: {wer}%\n") + if decode_path: + fout.write(f"SeedTTS WER: {wer_avg}%\n") + fout.write(f"WER: {wer}%\n") + fout.flush() + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + else: + device = torch.device("cpu") + main( + args.test_list, + args.wav_path, + args.model_path, + args.decode_path, + args.lang, + device, + ) diff --git a/egs/zipvoice/local/feature.py b/egs/zipvoice/local/feature.py new file mode 100644 index 000000000..e7d484d10 --- /dev/null +++ b/egs/zipvoice/local/feature.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torchaudio +from lhotse.features.base import FeatureExtractor, register_extractor +from lhotse.utils import Seconds, compute_num_frames + + +class MelSpectrogramFeatures(nn.Module): + def __init__( + self, + sampling_rate=24000, + n_mels=100, + n_fft=1024, + hop_length=256, + ): + super().__init__() + + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sampling_rate, + n_fft=n_fft, + hop_length=hop_length, + n_mels=n_mels, + center=True, + power=1, + ) + + def forward(self, inp): + assert len(inp.shape) == 2 + + mel = self.mel_spec(inp) + logmel = mel.clamp(min=1e-7).log() + return logmel + + +@dataclass +class TorchAudioFbankConfig: + sampling_rate: int + n_mels: int + n_fft: int + hop_length: int + + +@register_extractor +class TorchAudioFbank(FeatureExtractor): + + name = "TorchAudioFbank" + config_type = TorchAudioFbankConfig + + def __init__(self, config): + super().__init__(config=config) + + def _feature_fn(self, sample): + fbank = MelSpectrogramFeatures( + sampling_rate=self.config.sampling_rate, + n_mels=self.config.n_mels, + n_fft=self.config.n_fft, + hop_length=self.config.hop_length, + ) + + return fbank(sample) + + @property + def device(self) -> Union[str, torch.device]: + return self.config.device + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.n_mels + + def extract( + self, + samples: Union[np.ndarray, torch.Tensor], + sampling_rate: int, + ) -> Union[np.ndarray, torch.Tensor]: + # Check for sampling rate compatibility. + expected_sr = self.config.sampling_rate + assert sampling_rate == expected_sr, ( + f"Mismatched sampling rate: extractor expects {expected_sr}, " + f"got {sampling_rate}" + ) + is_numpy = False + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + is_numpy = True + + if len(samples.shape) == 1: + samples = samples.unsqueeze(0) + assert samples.ndim == 2, samples.shape + assert samples.shape[0] == 1, samples.shape + + mel = self._feature_fn(samples).squeeze().t() + + assert mel.ndim == 2, mel.shape + assert mel.shape[1] == self.config.n_mels, mel.shape + + num_frames = compute_num_frames( + samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate + ) + + if mel.shape[0] > num_frames: + mel = mel[:num_frames] + elif mel.shape[0] < num_frames: + mel = mel.unsqueeze(0) + mel = torch.nn.functional.pad( + mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" + ).squeeze(0) + + if is_numpy: + return mel.cpu().numpy() + else: + return mel + + @property + def frame_shift(self) -> Seconds: + return self.config.hop_length / self.config.sampling_rate diff --git a/egs/zipvoice/local/prepare_libritts.sh b/egs/zipvoice/local/prepare_libritts.sh new file mode 100755 index 000000000..b35065bb1 --- /dev/null +++ b/egs/zipvoice/local/prepare_libritts.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=5 +sampling_rate=24000 +nj=32 + +dl_dir=$PWD/download + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LibriTTS, + # you can create a symlink + # + # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS + # + if [ ! -d $dl_dir/LibriTTS ]; then + lhotse download libritts $dl_dir + fi + +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LibriTTS manifest" + # We assume that you have downloaded the LibriTTS corpus + # to $dl_dir/LibriTTS + mkdir -p data/manifests_libritts + if [ ! -e data/manifests_libritts/.libritts.done ]; then + lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests_libritts + touch data/manifests_libritts/.libritts.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute Fbank for LibriTTS" + mkdir -p data/fbank + if [ ! -e data/fbank_libritts/.libritts.done ]; then + ./local/compute_fbank_libritts.py --sampling-rate $sampling_rate + touch data/fbank_libritts/.libritts.done + fi + + # Here we shuffle and combine the train-clean-100, train-clean-360 and + # train-other-500 together to form the training set. + if [ ! -f data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank_libritts/libritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz + fi + + if [ ! -f data/fbank_libritts/libritts_cuts_train-clean-460.jsonl.gz ]; then + cat <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-360.jsonl.gz) | \ + shuf | gzip -c > data/fbank_libritts/libritts_cuts_train-clean-460.jsonl.gz + fi + + if [ ! -e data/fbank_libritts/.libritts-validated.done ]; then + log "Validating data/fbank for LibriTTS" + ./local/validate_manifest.py \ + data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz + touch data/fbank_libritts/.libritts-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 4: Generate token file" + if [ ! -e data/tokens_libritts.txt ]; then + ./local/prepare_token_file_libritts.py --tokens data/tokens_libritts.txt + fi +fi \ No newline at end of file diff --git a/egs/zipvoice/local/prepare_token_file_emilia.py b/egs/zipvoice/local/prepare_token_file_emilia.py new file mode 100644 index 000000000..68af8d397 --- /dev/null +++ b/egs/zipvoice/local/prepare_token_file_emilia.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file generates the file that maps tokens to IDs. +""" + +import argparse +import logging +from pathlib import Path +from typing import List + +from piper_phonemize import get_espeak_map +from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens_emilia.txt"), + help="Path to the dict that maps the text tokens to IDs", + ) + + parser.add_argument( + "--pinyin", + type=Path, + default=Path("local/pinyin.txt"), + help="Path to the all unique pinyin", + ) + + return parser.parse_args() + + +def get_pinyin_tokens(pinyin: Path) -> List[str]: + phones = set() + with open(pinyin, "r") as f: + for line in f: + x = line.strip() + initial = to_initials(x, strict=False) + # don't want to share tokens with espeak tokens, so use tone3 style + finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True) + if initial != "": + # don't want to share tokens with espeak tokens, so add a '0' after each initial + phones.add(initial + "0") + if finals != "": + phones.add(finals) + return sorted(phones) + + +def get_token2id(args): + """Get a dict that maps token to IDs, and save it to the given filename.""" + all_tokens = get_espeak_map() # token: [token_id] + all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()} + # sort by token_id + all_tokens = sorted(all_tokens.items(), key=lambda x: x[1]) + + all_pinyin = get_pinyin_tokens(args.pinyin) + with open(args.tokens, "w", encoding="utf-8") as f: + for token, token_id in all_tokens: + f.write(f"{token} {token_id}\n") + num_espeak_tokens = len(all_tokens) + for i, pinyin in enumerate(all_pinyin): + f.write(f"{pinyin} {num_espeak_tokens + i}\n") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + get_token2id(args) diff --git a/egs/zipvoice/local/prepare_token_file_libritts.py b/egs/zipvoice/local/prepare_token_file_libritts.py new file mode 100644 index 000000000..374b02613 --- /dev/null +++ b/egs/zipvoice/local/prepare_token_file_libritts.py @@ -0,0 +1,31 @@ +import re +from collections import Counter + +from lhotse import load_manifest_lazy + + +def prepare_tokens(manifest_file, token_file): + counter = Counter() + manifest = load_manifest_lazy(manifest_file) + for cut in manifest: + line = re.sub(r"\s+", " ", cut.supervisions[0].text) + counter.update(line) + + unique_chars = set(counter.keys()) + + if "_" in unique_chars: + unique_chars.remove("_") + + sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True) + + result = ["_"] + sorted_chars + + with open(token_file, "w", encoding="utf-8") as file: + for index, char in enumerate(result): + file.write(f"{char} {index}\n") + + +if __name__ == "__main__": + manifest_file = "data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz" + output_token_file = "data/tokens_libritts.txt" + prepare_tokens(manifest_file, output_token_file) diff --git a/egs/zipvoice/local/prepare_tokens_emilia.py b/egs/zipvoice/local/prepare_tokens_emilia.py new file mode 100644 index 000000000..023d57524 --- /dev/null +++ b/egs/zipvoice/local/prepare_tokens_emilia.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin, +# Wei Kang) +# 2024 Tsinghua University (authors: Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. +""" + +import argparse +import glob +import logging +import re +from concurrent.futures import ProcessPoolExecutor as Pool +from pathlib import Path +from typing import List + +import jieba +from lhotse import load_manifest_lazy +from tokenizer import Tokenizer, is_alphabet, is_chinese, is_hangul, is_japanese + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--subset", + type=str, + help="Subset of emilia, (ZH, EN, etc.)", + ) + + parser.add_argument( + "--jobs", + type=int, + default=50, + help="Number of jobs to processing.", + ) + + parser.add_argument( + "--source-dir", + type=str, + default="data/manifests_emilia/splits", + help="The source directory of manifest files.", + ) + + parser.add_argument( + "--dest-dir", + type=str, + help="The destination directory of manifest files.", + ) + + return parser.parse_args() + + +def tokenize_by_CJK_char(line: str) -> List[str]: + """ + Tokenize a line of text with CJK char. + + Note: All return characters will be upper case. + + Example: + input = "你好世界是 hello world 的中文" + output = [你, 好, 世, 界, 是, HELLO, WORLD, 的, 中, 文] + + Args: + line: + The input text. + + Return: + A new string tokenize by CJK char. + """ + # The CJK ranges is from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py + pattern = re.compile( + r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])" + ) + chars = pattern.split(line.strip().upper()) + char_list = [] + for w in chars: + if w.strip(): + char_list += w.strip().split() + return char_list + + +def prepare_tokens_emilia(file_name: str, input_dir: Path, output_dir: Path): + logging.info(f"Processing {file_name}") + if (output_dir / file_name).is_file(): + logging.info(f"{file_name} exists, skipping.") + return + jieba.setLogLevel(logging.INFO) + tokenizer = Tokenizer() + + def _prepare_cut(cut): + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) + text = cut.supervisions[0].text + cut.supervisions[0].normalized_text = text + tokens = tokenizer.texts_to_tokens([text])[0] + cut.tokens = tokens + return cut + + def _filter_cut(cut): + text = cut.supervisions[0].text + duration = cut.supervisions[0].duration + chinese = [] + english = [] + + # only contains chinese and space and alphabets + clean_chars = [] + for x in text: + if is_hangul(x): + logging.info(f"Delete cut with text containing Korean : {text}") + return False + if is_japanese(x): + logging.info(f"Delete cut with text containing Japanese : {text}") + return False + if is_chinese(x): + chinese.append(x) + clean_chars.append(x) + if is_alphabet(x): + english.append(x) + clean_chars.append(x) + if x == " ": + clean_chars.append(x) + if len(english) + len(chinese) == 0: + logging.info(f"Delete cut with text has no valid chars : {text}") + return False + + words = tokenize_by_CJK_char("".join(clean_chars)) + for i in range(len(words) - 10): + if words[i : i + 10].count(words[i]) == 10: + logging.info(f"Delete cut with text with too much repeats : {text}") + return False + # word speed, 20 - 600 / minute + if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60: + logging.info( + f"Delete cut with audio text mismatch, duration : {duration}s, words : {len(words)}, text : {text}" + ) + return False + return True + + try: + cut_set = load_manifest_lazy(input_dir / file_name) + cut_set = cut_set.filter(_filter_cut) + cut_set = cut_set.map(_prepare_cut) + cut_set.to_file(output_dir / file_name) + except Exception as e: + logging.error(f"Manifest {file_name} failed with error: {e}") + raise + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + + input_dir = Path(args.source_dir) + output_dir = Path(args.dest_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz") + + with Pool(max_workers=args.jobs) as pool: + futures = [ + pool.submit( + prepare_tokens_emilia, filename.split("/")[-1], input_dir, output_dir + ) + for filename in cut_files + ] + for f in futures: + try: + f.result() + f.done() + except Exception as e: + logging.error(f"Future failed with error: {e}") + logging.info("Processing done.") diff --git a/egs/zipvoice/local/validate_manifest.py b/egs/zipvoice/local/validate_manifest.py new file mode 100755 index 000000000..68159ae03 --- /dev/null +++ b/egs/zipvoice/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/ljspeech_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/zipvoice/zipvoice/checkpoint.py b/egs/zipvoice/zipvoice/checkpoint.py new file mode 100644 index 000000000..e3acd57dd --- /dev/null +++ b/egs/zipvoice/zipvoice/checkpoint.py @@ -0,0 +1,142 @@ +# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +import torch.nn as nn +from lhotse.dataset.sampling.base import CutSampler +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer + +# use duck typing for LRScheduler since we have different possibilities, see +# our class LRScheduler. +LRSchedulerType = object + + +def save_checkpoint( + filename: Path, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + model_ema: Optional[nn.Module] = None, + params: Optional[Dict[str, Any]] = None, + optimizer: Optional[Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +) -> None: + """Save training information to a file. + + Args: + filename: + The checkpoint filename. + model: + The model to be saved. We only save its `state_dict()`. + model_avg: + The stored model averaged from the start of training. + model_ema: + The EMA version of model. + params: + User defined parameters, e.g., epoch, loss. + optimizer: + The optimizer to be saved. We only save its `state_dict()`. + scheduler: + The scheduler to be saved. We only save its `state_dict()`. + scalar: + The GradScaler to be saved. We only save its `state_dict()`. + sampler: + The sampler used in the labeled training dataset. We only + save its `state_dict()`. + rank: + Used in DDP. We save checkpoint only for the node whose + rank is 0. + Returns: + Return None. + """ + if rank != 0: + return + + logging.info(f"Saving checkpoint to {filename}") + + if isinstance(model, DDP): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "scheduler": scheduler.state_dict() if scheduler is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if model_avg is not None: + checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() + if model_ema is not None: + checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict() + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + + torch.save(checkpoint, filename) + + +def load_checkpoint( + filename: Path, + model: Optional[nn.Module] = None, + model_avg: Optional[nn.Module] = None, + model_ema: Optional[nn.Module] = None, + strict: bool = False, +) -> Dict[str, Any]: + logging.info(f"Loading checkpoint from {filename}") + checkpoint = torch.load(filename, map_location="cpu") + + if model is not None: + + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + logging.info("Loading checkpoint") + model.load_state_dict(checkpoint["model"], strict=strict) + + checkpoint.pop("model") + + if model_avg is not None and "model_avg" in checkpoint: + logging.info("Loading averaged model") + model_avg.load_state_dict(checkpoint["model_avg"], strict=strict) + checkpoint.pop("model_avg") + + if model_ema is not None and "model_ema" in checkpoint: + logging.info("Loading ema model") + model_ema.load_state_dict(checkpoint["model_ema"], strict=strict) + checkpoint.pop("model_ema") + + return checkpoint diff --git a/egs/zipvoice/zipvoice/feature.py b/egs/zipvoice/zipvoice/feature.py new file mode 100644 index 000000000..e7d484d10 --- /dev/null +++ b/egs/zipvoice/zipvoice/feature.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torchaudio +from lhotse.features.base import FeatureExtractor, register_extractor +from lhotse.utils import Seconds, compute_num_frames + + +class MelSpectrogramFeatures(nn.Module): + def __init__( + self, + sampling_rate=24000, + n_mels=100, + n_fft=1024, + hop_length=256, + ): + super().__init__() + + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sampling_rate, + n_fft=n_fft, + hop_length=hop_length, + n_mels=n_mels, + center=True, + power=1, + ) + + def forward(self, inp): + assert len(inp.shape) == 2 + + mel = self.mel_spec(inp) + logmel = mel.clamp(min=1e-7).log() + return logmel + + +@dataclass +class TorchAudioFbankConfig: + sampling_rate: int + n_mels: int + n_fft: int + hop_length: int + + +@register_extractor +class TorchAudioFbank(FeatureExtractor): + + name = "TorchAudioFbank" + config_type = TorchAudioFbankConfig + + def __init__(self, config): + super().__init__(config=config) + + def _feature_fn(self, sample): + fbank = MelSpectrogramFeatures( + sampling_rate=self.config.sampling_rate, + n_mels=self.config.n_mels, + n_fft=self.config.n_fft, + hop_length=self.config.hop_length, + ) + + return fbank(sample) + + @property + def device(self) -> Union[str, torch.device]: + return self.config.device + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.n_mels + + def extract( + self, + samples: Union[np.ndarray, torch.Tensor], + sampling_rate: int, + ) -> Union[np.ndarray, torch.Tensor]: + # Check for sampling rate compatibility. + expected_sr = self.config.sampling_rate + assert sampling_rate == expected_sr, ( + f"Mismatched sampling rate: extractor expects {expected_sr}, " + f"got {sampling_rate}" + ) + is_numpy = False + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + is_numpy = True + + if len(samples.shape) == 1: + samples = samples.unsqueeze(0) + assert samples.ndim == 2, samples.shape + assert samples.shape[0] == 1, samples.shape + + mel = self._feature_fn(samples).squeeze().t() + + assert mel.ndim == 2, mel.shape + assert mel.shape[1] == self.config.n_mels, mel.shape + + num_frames = compute_num_frames( + samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate + ) + + if mel.shape[0] > num_frames: + mel = mel[:num_frames] + elif mel.shape[0] < num_frames: + mel = mel.unsqueeze(0) + mel = torch.nn.functional.pad( + mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" + ).squeeze(0) + + if is_numpy: + return mel.cpu().numpy() + else: + return mel + + @property + def frame_shift(self) -> Seconds: + return self.config.hop_length / self.config.sampling_rate diff --git a/egs/zipvoice/zipvoice/generate_averaged_model.py b/egs/zipvoice/zipvoice/generate_averaged_model.py new file mode 100755 index 000000000..e1b7ca7c6 --- /dev/null +++ b/egs/zipvoice/zipvoice/generate_averaged_model.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +This script loads checkpoints and averages them. + +(1) Average ZipVoice models before distill: + python3 ./zipvoice/generate_averaged_model.py \ + --epoch 11 \ + --avg 4 \ + --distill 0 \ + --token-file data/tokens_emilia.txt \ + --exp-dir ./zipvoice/exp_zipvoice + + It will generate a file `epoch-11-avg-14.pt` in the given `exp_dir`. + You can later load it by `torch.load("epoch-11-avg-4.pt")`. + +(2) Average ZipVoice-Distill models (the first stage model): + + python3 ./zipvoice/generate_averaged_model.py \ + --iter 60000 \ + --avg 7 \ + --distill 1 \ + --token-file data/tokens_emilia.txt \ + --exp-dir ./zipvoice/exp_zipvoice_distill_1stage +""" + +import argparse +from pathlib import Path + +import torch +from model import get_distill_model, get_model +from tokenizer import TokenizerEmilia, TokenizerLibriTTS +from train_flow import add_model_arguments, get_params + +from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=11, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=4, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' or --iter", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipvoice/exp_zipvoice", + help="The experiment dir", + ) + + parser.add_argument( + "--distill", + type=str2bool, + default=False, + help="Whether to use distill model. ", + ) + + parser.add_argument( + "--dataset", + type=str, + default="emilia", + choices=["emilia", "libritts"], + help="The used training dataset for the model to inference", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.dataset == "emilia": + tokenizer = TokenizerEmilia( + token_file=params.token_file, token_type=params.token_type + ) + elif params.dataset == "libritts": + tokenizer = TokenizerLibriTTS( + token_file=params.token_file, token_type=params.token_type + ) + + params.vocab_size = tokenizer.vocab_size + params.pad_id = tokenizer.pad_id + + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + print("Script started") + + params.device = torch.device("cpu") + print(f"Device: {params.device}") + + print("About to create model") + if params.distill: + model = get_distill_model(params) + else: + model = get_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + print( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(params.device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=params.device, + ), + strict=True, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + print( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(params.device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=params.device, + ), + strict=True, + ) + if params.iter > 0: + filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" + else: + filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" + torch.save({"model": model.state_dict()}, filename) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/zipvoice/zipvoice/infer.py b/egs/zipvoice/zipvoice/infer.py new file mode 100644 index 000000000..2819d3c85 --- /dev/null +++ b/egs/zipvoice/zipvoice/infer.py @@ -0,0 +1,586 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang +# Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script loads checkpoints to generate waveforms. +This script is supposed to be used with the model trained by yourself. +If you want to use the pre-trained checkpoints provided by us, please refer to zipvoice_infer.py. + +(1) Usage with a pre-trained checkpoint: + + (a) ZipVoice model before distill: + python3 zipvoice/infer.py \ + --checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \ + --distill 0 \ + --token-file "data/tokens_emilia.txt" \ + --test-list test.tsv \ + --res-dir results/test \ + --num-step 16 \ + --guidance-scale 1 + + (b) ZipVoice-Distill: + python3 zipvoice/infer.py \ + --checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \ + --distill 1 \ + --token-file "data/tokens_emilia.txt" \ + --test-list test.tsv \ + --res-dir results/test_distill \ + --num-step 8 \ + --guidance-scale 3 + +(2) Usage with a directory of checkpoints (may requires checkpoint averaging): + + (a) ZipVoice model before distill: + python3 flow_match/infer.py \ + --exp-dir zipvoice/exp_zipvoice \ + --epoch 11 \ + --avg 4 \ + --distill 0 \ + --token-file "data/tokens_emilia.txt" \ + --test-list test.tsv \ + --res-dir results \ + --num-step 16 \ + --guidance-scale 1 + + (b) ZipVoice-Distill: + python3 flow_match/infer.py \ + --exp-dir zipvoice/exp_zipvoice_distill/ \ + --iter 2000 \ + --avg 0 \ + --distill 1 \ + --token-file "data/tokens_emilia.txt" \ + --test-list test.tsv \ + --res-dir results \ + --num-step 8 \ + --guidance-scale 3 +""" + + +import argparse +import datetime as dt +import logging +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import soundfile as sf +import torch +import torch.nn as nn +import torchaudio +from checkpoint import load_checkpoint +from feature import TorchAudioFbank, TorchAudioFbankConfig +from lhotse.utils import fix_random_seed +from model import get_distill_model, get_model +from tokenizer import TokenizerEmilia, TokenizerLibriTTS +from train_flow import add_model_arguments, get_params +from vocos import Vocos + +from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="The checkpoint for inference. " + "If it is None, will use checkpoints under exp_dir", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipvoice/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--epoch", + type=int, + default=0, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=4, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' or '--iter', avg=0 means no avg", + ) + + parser.add_argument( + "--vocoder-path", + type=str, + default=None, + help="The local vocos vocoder path, downloaded from huggingface, " + "will download the vocodoer from huggingface if it is None.", + ) + + parser.add_argument( + "--distill", + type=str2bool, + default=False, + help="Whether it is a distilled TTS model.", + ) + + parser.add_argument( + "--test-list", + type=str, + default=None, + help="The list of prompt speech, prompt_transcription, " + "and text to synthesize in the format of " + "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.", + ) + + parser.add_argument( + "--res-dir", + type=str, + default="results", + help="Path name of the generated wavs dir", + ) + + parser.add_argument( + "--dataset", + type=str, + default="emilia", + choices=["emilia", "libritts"], + help="The used training dataset for the model to inference", + ) + + parser.add_argument( + "--guidance-scale", + type=float, + default=1.0, + help="The scale of classifier-free guidance during inference.", + ) + + parser.add_argument( + "--num-step", + type=int, + default=16, + help="The number of sampling steps.", + ) + + parser.add_argument( + "--feat-scale", + type=float, + default=0.1, + help="The scale factor of fbank feature", + ) + + parser.add_argument( + "--speed", + type=float, + default=1.0, + help="Control speech speed, 1.0 means normal, >1.0 means speed up", + ) + + parser.add_argument( + "--t-shift", + type=float, + default=0.5, + help="Shift t to smaller ones if t_shift < 1.0", + ) + + parser.add_argument( + "--target-rms", + type=float, + default=0.1, + help="Target speech normalization rms value", + ) + + parser.add_argument( + "--seed", + type=int, + default=666, + help="Random seed", + ) + + add_model_arguments(parser) + + return parser + + +def get_vocoder(vocos_local_path: Optional[str] = None): + if vocos_local_path: + vocos_local_path = "model/huggingface/vocos-mel-24khz/" + vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") + state_dict = torch.load( + f"{vocos_local_path}/pytorch_model.bin", + weights_only=True, + map_location="cpu", + ) + vocoder.load_state_dict(state_dict) + else: + vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz") + return vocoder + + +def generate_sentence( + save_path: str, + prompt_text: str, + prompt_wav: str, + text: str, + model: nn.Module, + vocoder: nn.Module, + tokenizer: TokenizerEmilia, + feature_extractor: TorchAudioFbank, + device: torch.device, + num_step: int = 16, + guidance_scale: float = 1.0, + speed: float = 1.0, + t_shift: float = 0.5, + target_rms: float = 0.1, + feat_scale: float = 0.1, + sampling_rate: int = 24000, +): + """ + Generate waveform of a text based on a given prompt + waveform and its transcription. + + Args: + save_path (str): Path to save the generated wav. + prompt_text (str): Transcription of the prompt wav. + prompt_wav (str): Path to the prompt wav file. + text (str): Text to be synthesized into a waveform. + model (nn.Module): The model used for generation. + vocoder (nn.Module): The vocoder used to convert features to waveforms. + tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens. + feature_extractor (TorchAudioFbank): The feature extractor used to + extract acoustic features. + device (torch.device): The device on which computations are performed. + num_step (int, optional): Number of steps for decoding. Defaults to 16. + guidance_scale (float, optional): Scale for classifier-free guidance. + Defaults to 1.0. + speed (float, optional): Speed control. Defaults to 1.0. + t_shift (float, optional): Time shift. Defaults to 0.5. + target_rms (float, optional): Target RMS for waveform normalization. + Defaults to 0.1. + feat_scale (float, optional): Scale for features. + Defaults to 0.1. + sampling_rate (int, optional): Sampling rate for the waveform. + Defaults to 24000. + Returns: + metrics (dict): Dictionary containing time and real-time + factor metrics for processing. + """ + # Convert text to tokens + tokens = tokenizer.texts_to_token_ids([text]) + prompt_tokens = tokenizer.texts_to_token_ids([prompt_text]) + + # Load and preprocess prompt wav + prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav) + prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav))) + if prompt_rms < target_rms: + prompt_wav = prompt_wav * target_rms / prompt_rms + + if prompt_sampling_rate != sampling_rate: + resampler = torchaudio.transforms.Resample( + orig_freq=prompt_sampling_rate, new_freq=sampling_rate + ) + prompt_wav = resampler(prompt_wav) + + # Extract features from prompt wav + prompt_features = feature_extractor.extract( + prompt_wav, sampling_rate=sampling_rate + ).to(device) + prompt_features = prompt_features.unsqueeze(0) * feat_scale + prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device) + + # Start timing + start_t = dt.datetime.now() + + # Generate features + ( + pred_features, + pred_features_lens, + pred_prompt_features, + pred_prompt_features_lens, + ) = model.sample( + tokens=tokens, + prompt_tokens=prompt_tokens, + prompt_features=prompt_features, + prompt_features_lens=prompt_features_lens, + speed=speed, + t_shift=t_shift, + duration="predict", + num_step=num_step, + guidance_scale=guidance_scale, + ) + + # Postprocess predicted features + pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T) + + # Start vocoder processing + start_vocoder_t = dt.datetime.now() + wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1) + + # Calculate processing times and real-time factors + t = (dt.datetime.now() - start_t).total_seconds() + t_no_vocoder = (start_vocoder_t - start_t).total_seconds() + t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds() + wav_seconds = wav.shape[-1] / sampling_rate + rtf = t / wav_seconds + rtf_no_vocoder = t_no_vocoder / wav_seconds + rtf_vocoder = t_vocoder / wav_seconds + metrics = { + "t": t, + "t_no_vocoder": t_no_vocoder, + "t_vocoder": t_vocoder, + "wav_seconds": wav_seconds, + "rtf": rtf, + "rtf_no_vocoder": rtf_no_vocoder, + "rtf_vocoder": rtf_vocoder, + } + + # Adjust wav volume if necessary + if prompt_rms < target_rms: + wav = wav * prompt_rms / target_rms + wav = wav[0].cpu().numpy() + sf.write(save_path, wav, sampling_rate) + + return metrics + + +def generate( + params: AttributeDict, + model: nn.Module, + vocoder: nn.Module, + tokenizer: TokenizerEmilia, +): + total_t = [] + total_t_no_vocoder = [] + total_t_vocoder = [] + total_wav_seconds = [] + + config = TorchAudioFbankConfig( + sampling_rate=params.sampling_rate, + n_mels=100, + n_fft=1024, + hop_length=256, + ) + feature_extractor = TorchAudioFbank(config) + + with open(params.test_list, "r") as fr: + lines = fr.readlines() + + for i, line in enumerate(lines): + wav_name, prompt_text, prompt_wav, text = line.strip().split("\t") + save_path = f"{params.wav_dir}/{wav_name}.wav" + metrics = generate_sentence( + save_path=save_path, + prompt_text=prompt_text, + prompt_wav=prompt_wav, + text=text, + model=model, + vocoder=vocoder, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + device=params.device, + num_step=params.num_step, + guidance_scale=params.guidance_scale, + speed=params.speed, + t_shift=params.t_shift, + target_rms=params.target_rms, + feat_scale=params.feat_scale, + sampling_rate=params.sampling_rate, + ) + print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}") + total_t.append(metrics["t"]) + total_t_no_vocoder.append(metrics["t_no_vocoder"]) + total_t_vocoder.append(metrics["t_vocoder"]) + total_wav_seconds.append(metrics["wav_seconds"]) + + print(f"Average RTF: " f"{np.sum(total_t)/np.sum(total_wav_seconds):.4f}") + print( + f"Average RTF w/o vocoder: " + f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}" + ) + print( + f"Average RTF vocoder: " + f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}" + ) + + +@torch.inference_mode() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = ( + f"wavs-iter-{params.iter}-avg" + f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}" + ) + elif params.epoch > 0: + params.suffix = ( + f"wavs-epoch-{params.epoch}-avg" + f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}" + ) + else: + params.suffix = "wavs" + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Decoding started") + + if torch.cuda.is_available(): + params.device = torch.device("cuda", 0) + else: + params.device = torch.device("cpu") + + logging.info(f"Device: {params.device}") + + if params.dataset == "emilia": + tokenizer = TokenizerEmilia( + token_file=params.token_file, token_type=params.token_type + ) + elif params.dataset == "libritts": + tokenizer = TokenizerLibriTTS( + token_file=params.token_file, token_type=params.token_type + ) + + params.vocab_size = tokenizer.vocab_size + params.pad_id = tokenizer.pad_id + + logging.info(params) + fix_random_seed(params.seed) + + logging.info("About to create model") + if params.distill: + model = get_distill_model(params) + else: + model = get_model(params) + + if params.checkpoint: + load_checkpoint(params.checkpoint, model, strict=True) + else: + if params.avg == 0: + if params.iter > 0: + load_checkpoint( + f"{params.exp_dir}/checkpoint-{params.iter}.pt", model, strict=True + ) + else: + load_checkpoint( + f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(params.device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=params.device, + ), + strict=True, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(params.device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=params.device, + ), + strict=True, + ) + + model = model.to(params.device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + vocoder = get_vocoder(params.vocoder_path) + vocoder = vocoder.to(params.device) + vocoder.eval() + num_param = sum([p.numel() for p in vocoder.parameters()]) + logging.info(f"Number of vocoder parameters: {num_param}") + + params.wav_dir = f"{params.res_dir}/{params.suffix}" + os.makedirs(params.wav_dir, exist_ok=True) + + assert ( + params.test_list is not None + ), "Please provide --test-list for speech synthesize." + generate( + params=params, + model=model, + vocoder=vocoder, + tokenizer=tokenizer, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/zipvoice/zipvoice/model.py b/egs/zipvoice/zipvoice/model.py new file mode 100644 index 000000000..25c7973b2 --- /dev/null +++ b/egs/zipvoice/zipvoice/model.py @@ -0,0 +1,578 @@ +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang +# Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import torch +import torch.nn as nn +from scaling import ScheduledFloat +from solver import EulerSolver +from torch.nn.parallel import DistributedDataParallel as DDP +from utils import ( + AttributeDict, + condition_time_mask, + get_tokens_index, + make_pad_mask, + pad_labels, + prepare_avg_tokens_durations, + to_int_tuple, +) +from zipformer import TTSZipformer + + +def get_model(params: AttributeDict) -> nn.Module: + """Get the normal TTS model.""" + + fm_decoder = get_fm_decoder_model(params) + text_encoder = get_text_encoder_model(params) + + model = TtsModel( + fm_decoder=fm_decoder, + text_encoder=text_encoder, + text_embed_dim=params.text_embed_dim, + feat_dim=params.feat_dim, + vocab_size=params.vocab_size, + pad_id=params.pad_id, + ) + return model + + +def get_distill_model(params: AttributeDict) -> nn.Module: + """Get the distillation TTS model.""" + + fm_decoder = get_fm_decoder_model(params, distill=True) + text_encoder = get_text_encoder_model(params) + + model = DistillTTSModelTrainWrapper( + fm_decoder=fm_decoder, + text_encoder=text_encoder, + text_embed_dim=params.text_embed_dim, + feat_dim=params.feat_dim, + vocab_size=params.vocab_size, + pad_id=params.pad_id, + ) + return model + + +def get_fm_decoder_model(params: AttributeDict, distill: bool = False) -> nn.Module: + """Get the Zipformer-based FM decoder model.""" + + encoder = TTSZipformer( + in_dim=params.feat_dim * 3, + out_dim=params.feat_dim, + downsampling_factor=to_int_tuple(params.fm_decoder_downsampling_factor), + num_encoder_layers=to_int_tuple(params.fm_decoder_num_layers), + cnn_module_kernel=to_int_tuple(params.fm_decoder_cnn_module_kernel), + encoder_dim=params.fm_decoder_dim, + feedforward_dim=params.fm_decoder_feedforward_dim, + num_heads=params.fm_decoder_num_heads, + query_head_dim=params.query_head_dim, + pos_head_dim=params.pos_head_dim, + value_head_dim=params.value_head_dim, + pos_dim=params.pos_dim, + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + use_time_embed=True, + time_embed_dim=192, + use_guidance_scale_embed=distill, + ) + return encoder + + +def get_text_encoder_model(params: AttributeDict) -> nn.Module: + """Get the Zipformer-based text encoder model.""" + + encoder = TTSZipformer( + in_dim=params.text_embed_dim, + out_dim=params.feat_dim, + downsampling_factor=to_int_tuple(params.text_encoder_downsampling_factor), + num_encoder_layers=to_int_tuple(params.text_encoder_num_layers), + cnn_module_kernel=to_int_tuple(params.text_encoder_cnn_module_kernel), + encoder_dim=params.text_encoder_dim, + feedforward_dim=params.text_encoder_feedforward_dim, + num_heads=params.text_encoder_num_heads, + query_head_dim=params.query_head_dim, + pos_head_dim=params.pos_head_dim, + value_head_dim=params.value_head_dim, + pos_dim=params.pos_dim, + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + use_time_embed=False, + ) + return encoder + + +class TtsModel(nn.Module): + """The normal TTS model.""" + + def __init__( + self, + fm_decoder: nn.Module, + text_encoder: nn.Module, + text_embed_dim: int, + feat_dim: int, + vocab_size: int, + pad_id: int = 0, + ): + """ + Args: + fm_decoder: the flow-matching encoder model, inputs are the + input condition embeddings and noisy acoustic features, + outputs are better acoustic features. + text_encoder: the text encoder model. input are text + embeddings, output are contextualized text embeddings. + text_embed_dim: dimension of text embedding. + feat_dim: dimension of acoustic features. + vocab_size: vocabulary size. + pad_id: padding id. + """ + super().__init__() + + self.feat_dim = feat_dim + self.text_embed_dim = text_embed_dim + self.pad_id = pad_id + + self.fm_decoder = fm_decoder + + self.text_encoder = text_encoder + + self.embed = nn.Embedding(vocab_size, text_embed_dim) + + self.distill = False + + def forward_fm_decoder( + self, + t: torch.Tensor, + xt: torch.Tensor, + text_condition: torch.Tensor, + speech_condition: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + guidance_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute velocity. + Args: + t: A tensor of shape (N, 1, 1) or a tensor of a float, + in the range of (0, 1). + xt: the input of the current timestep, including condition + embeddings and noisy acoustic features. + text_condition: the text condition embeddings, with the + shape (batch, seq_len, emb_dim). + speech_condition: the speech condition embeddings, with the + shape (batch, seq_len, emb_dim). + padding_mask: The mask for padding, True means masked + position, with the shape (N, T). + guidance_scale: The guidance scale in classifier-free guidance, + which is a tensor of shape (N, 1, 1) or a tensor of a float. + + Returns: + predicted velocity, with the shape (batch, seq_len, emb_dim). + """ + assert t.dim() in (0, 3) + # Handle t with the shape (N, 1, 1): + # squeeze the last dimension if it's size is 1. + while t.dim() > 1 and t.size(-1) == 1: + t = t.squeeze(-1) + if guidance_scale is not None: + while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1: + guidance_scale = guidance_scale.squeeze(-1) + # Handle t with a single value: expand to the size of batch size. + if t.dim() == 0: + t = t.repeat(xt.shape[0]) + if guidance_scale is not None and guidance_scale.dim() == 0: + guidance_scale = guidance_scale.repeat(xt.shape[0]) + + xt = torch.cat([xt, text_condition, speech_condition], dim=2) + vt = self.fm_decoder( + x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale + ) + return vt + + def forward_text_embed( + self, + tokens: List[List[int]], + ): + """ + Get the text embeddings. + Args: + tokens: a list of list of token ids. + Returns: + embed: the text embeddings, shape (batch, seq_len, emb_dim). + tokens_lens: the length of each token sequence, shape (batch,). + """ + device = ( + self.device if isinstance(self, DDP) else next(self.parameters()).device + ) + tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S) + embed = self.embed(tokens_padded) # (B, S, C) + tokens_lens = torch.tensor( + [len(token) for token in tokens], dtype=torch.int64, device=device + ) + tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S) + + embed = self.text_encoder( + x=embed, t=None, padding_mask=tokens_padding_mask + ) # (B, S, C) + return embed, tokens_lens + + def forward_text_condition( + self, + embed: torch.Tensor, + tokens_lens: torch.Tensor, + features_lens: torch.Tensor, + ): + """ + Get the text condition with the same length of the acoustic feature. + Args: + embed: the text embeddings, shape (batch, token_seq_len, emb_dim). + tokens_lens: the length of each token sequence, shape (batch,). + features_lens: the length of each acoustic feature sequence, + shape (batch,). + Returns: + text_condition: the text condition, shape + (batch, feature_seq_len, emb_dim). + padding_mask: the padding mask of text condition, shape + (batch, feature_seq_len). + """ + + num_frames = int(features_lens.max()) + + padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T) + + tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens) + + tokens_index = get_tokens_index(tokens_durations, num_frames).to( + embed.device + ) # (B, T) + + text_condition = torch.gather( + embed, + dim=1, + index=tokens_index.unsqueeze(-1).expand( + embed.size(0), num_frames, embed.size(-1) + ), + ) # (B, T, F) + return text_condition, padding_mask + + def forward_text_train( + self, + tokens: List[List[int]], + features_lens: torch.Tensor, + ): + """ + Process text for training, given text tokens and real feature lengths. + """ + embed, tokens_lens = self.forward_text_embed(tokens) + text_condition, padding_mask = self.forward_text_condition( + embed, tokens_lens, features_lens + ) + return ( + text_condition, + padding_mask, + ) + + def forward_text_inference_gt_duration( + self, + tokens: List[List[int]], + features_lens: torch.Tensor, + prompt_tokens: List[List[int]], + prompt_features_lens: torch.Tensor, + ): + """ + Process text for inference, given text tokens, real feature lengths and prompts. + """ + tokens = [ + prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens) + ] + features_lens = prompt_features_lens + features_lens + embed, tokens_lens = self.forward_text_embed(tokens) + text_condition, padding_mask = self.forward_text_condition( + embed, tokens_lens, features_lens + ) + return text_condition, padding_mask + + def forward_text_inference_ratio_duration( + self, + tokens: List[List[int]], + prompt_tokens: List[List[int]], + prompt_features_lens: torch.Tensor, + speed: float, + ): + """ + Process text for inference, given text tokens and prompts, + feature lengths are predicted with the ratio of token numbers. + """ + device = ( + self.device if isinstance(self, DDP) else next(self.parameters()).device + ) + + cat_tokens = [ + prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens) + ] + + prompt_tokens_lens = torch.tensor( + [len(token) for token in prompt_tokens], dtype=torch.int64, device=device + ) + + cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens) + + features_lens = torch.ceil( + (prompt_features_lens / prompt_tokens_lens * cat_tokens_lens / speed) + ).to(dtype=torch.int64) + + text_condition, padding_mask = self.forward_text_condition( + cat_embed, cat_tokens_lens, features_lens + ) + return text_condition, padding_mask + + def forward( + self, + tokens: List[List[int]], + features: torch.Tensor, + features_lens: torch.Tensor, + noise: torch.Tensor, + t: torch.Tensor, + condition_drop_ratio: float = 0.0, + ) -> torch.Tensor: + """Forward pass of the model for training. + Args: + tokens: a list of list of token ids. + features: the acoustic features, with the shape (batch, seq_len, feat_dim). + features_lens: the length of each acoustic feature sequence, shape (batch,). + noise: the intitial noise, with the shape (batch, seq_len, feat_dim). + t: the time step, with the shape (batch, 1, 1). + condition_drop_ratio: the ratio of dropped text condition. + Returns: + fm_loss: the flow-matching loss. + """ + + (text_condition, padding_mask,) = self.forward_text_train( + tokens=tokens, + features_lens=features_lens, + ) + + speech_condition_mask = condition_time_mask( + features_lens=features_lens, + mask_percent=(0.7, 1.0), + max_len=features.size(1), + ) + speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features) + + if condition_drop_ratio > 0.0: + drop_mask = ( + torch.rand(text_condition.size(0), 1, 1).to(text_condition.device) + > condition_drop_ratio + ) + text_condition = text_condition * drop_mask + + xt = features * t + noise * (1 - t) + ut = features - noise # (B, T, F) + + vt = self.forward_fm_decoder( + t=t, + xt=xt, + text_condition=text_condition, + speech_condition=speech_condition, + padding_mask=padding_mask, + ) + + loss_mask = speech_condition_mask & (~padding_mask) + fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2) + + return fm_loss + + def sample( + self, + tokens: List[List[int]], + prompt_tokens: List[List[int]], + prompt_features: torch.Tensor, + prompt_features_lens: torch.Tensor, + features_lens: Optional[torch.Tensor] = None, + speed: float = 1.0, + t_shift: float = 1.0, + duration: str = "predict", + num_step: int = 5, + guidance_scale: float = 0.5, + ) -> torch.Tensor: + """ + Generate acoustic features, given text tokens, prompts feature + and prompt transcription's text tokens. + Args: + tokens: a list of list of text tokens. + prompt_tokens: a list of list of prompt tokens. + prompt_features: the prompt feature with the shape + (batch_size, seq_len, feat_dim). + prompt_features_lens: the length of each prompt feature, + with the shape (batch_size,). + features_lens: the length of the predicted eature, with the + shape (batch_size,). It is used only when duration is "real". + duration: "real" or "predict". If "real", the predicted + feature length is given by features_lens. + num_step: the number of steps to use in the ODE solver. + guidance_scale: the guidance scale for classifier-free guidance. + distill: whether to use the distillation model for sampling. + """ + + assert duration in ["real", "predict"] + + if duration == "predict": + ( + text_condition, + padding_mask, + ) = self.forward_text_inference_ratio_duration( + tokens=tokens, + prompt_tokens=prompt_tokens, + prompt_features_lens=prompt_features_lens, + speed=speed, + ) + else: + assert features_lens is not None + text_condition, padding_mask = self.forward_text_inference_gt_duration( + tokens=tokens, + features_lens=features_lens, + prompt_tokens=prompt_tokens, + prompt_features_lens=prompt_features_lens, + ) + batch_size, num_frames, _ = text_condition.shape + + speech_condition = torch.nn.functional.pad( + prompt_features, (0, 0, 0, num_frames - prompt_features.size(1)) + ) # (B, T, F) + + # False means speech condition positions. + speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames) + speech_condition = torch.where( + speech_condition_mask.unsqueeze(-1), + torch.zeros_like(speech_condition), + speech_condition, + ) + + x0 = torch.randn( + batch_size, num_frames, self.feat_dim, device=text_condition.device + ) + solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder") + + x1 = solver.sample( + x=x0, + text_condition=text_condition, + speech_condition=speech_condition, + padding_mask=padding_mask, + num_step=num_step, + guidance_scale=guidance_scale, + t_shift=t_shift, + ) + x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens + x1_prompt = torch.zeros( + x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device + ) + x1_wo_prompt = torch.zeros( + x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device + ) + for i in range(x1.size(0)): + x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[ + i, + prompt_features_lens[i] : prompt_features_lens[i] + + x1_wo_prompt_lens[i], + ] + x1_prompt[i, : prompt_features_lens[i], :] = x1[ + i, : prompt_features_lens[i] + ] + + return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens + + def sample_intermediate( + self, + tokens: List[List[int]], + features: torch.Tensor, + features_lens: torch.Tensor, + noise: torch.Tensor, + speech_condition_mask: torch.Tensor, + t_start: torch.Tensor, + t_end: torch.Tensor, + num_step: int = 1, + guidance_scale: torch.Tensor = None, + ) -> torch.Tensor: + """ + Generate acoustic features in intermediate timesteps. + Args: + tokens: List of list of token ids. + features: The acoustic features, with the shape (batch, seq_len, feat_dim). + features_lens: The length of each acoustic feature sequence, + with the shape (batch,). + noise: The initial noise, with the shape (batch, seq_len, feat_dim). + speech_condition_mask: The mask for speech condition, True means + non-condition positions, with the shape (batch, seq_len). + t_start: The start timestep, with the shape (batch, 1, 1). + t_end: The end timestep, with the shape (batch, 1, 1). + num_step: The number of steps for sampling. + guidance_scale: The scale for classifier-free guidance inference, + with the shape (batch, 1, 1). + distill: Whether to use distillation model. + """ + (text_condition, padding_mask,) = self.forward_text_train( + tokens=tokens, + features_lens=features_lens, + ) + + speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features) + + solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder") + + x_t_end = solver.sample( + x=noise, + text_condition=text_condition, + speech_condition=speech_condition, + padding_mask=padding_mask, + num_step=num_step, + guidance_scale=guidance_scale, + t_start=t_start, + t_end=t_end, + ) + x_t_end_lens = (~padding_mask).sum(-1) + return x_t_end, x_t_end_lens + + +class DistillTTSModelTrainWrapper(TtsModel): + """Wrapper for training the distilled TTS model.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.distill = True + + def forward( + self, + tokens: List[List[int]], + features: torch.Tensor, + features_lens: torch.Tensor, + noise: torch.Tensor, + speech_condition_mask: torch.Tensor, + t_start: torch.Tensor, + t_end: torch.Tensor, + num_step: int = 1, + guidance_scale: torch.Tensor = None, + ) -> torch.Tensor: + + return self.sample_intermediate( + tokens=tokens, + features=features, + features_lens=features_lens, + noise=noise, + speech_condition_mask=speech_condition_mask, + t_start=t_start, + t_end=t_end, + num_step=num_step, + guidance_scale=guidance_scale, + ) diff --git a/egs/zipvoice/zipvoice/optim.py b/egs/zipvoice/zipvoice/optim.py new file mode 100644 index 000000000..daf17556a --- /dev/null +++ b/egs/zipvoice/zipvoice/optim.py @@ -0,0 +1,1256 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import random +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.optim import Optimizer + + +class BatchedOptimizer(Optimizer): + """ + This class adds to class Optimizer the capability to optimize parameters in batches: + it will stack the parameters and their grads for you so the optimizer can work + on tensors with an extra leading dimension. This is intended for speed with GPUs, + as it reduces the number of kernels launched in the optimizer. + + Args: + params: + """ + + def __init__(self, params, defaults): + super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager + def batched_params(self, param_group, group_params_names): + """ + This function returns (technically, yields) a list of + of tuples (p, state), where + p is a `fake` parameter that is stacked (over axis 0) from real parameters + that share the same shape, and its gradient is also stacked; + `state` is the state corresponding to this batch of parameters + (it will be physically located in the "state" for one of the real + parameters, the last one that has any particular shape and dtype). + + This function is decorated as a context manager so that it can + write parameters back to their "real" locations. + + The idea is, instead of doing: + + for p in group["params"]: + state = self.state[p] + ... + + you can do: + + with self.batched_params(group["params"]) as batches: + for p, state, p_names in batches: + ... + + + Args: + group: a parameter group, which is a list of parameters; should be + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. + """ + batches = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + + assert len(param_group) == len(group_params_names) + for p, named_p in zip(param_group, group_params_names): + key = (str(p.dtype), *p.shape) + batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] + + stacked_params_dict = dict() + + # turn batches into a list, in deterministic order. + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] + + for batch, batch_names in zip(batches, batches_names): + p = batch[0] + # we arbitrarily store the state in the + # state corresponding to the 1st parameter in the + # group. class Optimizer will take care of saving/loading state. + state = self.state[p] + p_stacked = torch.stack(batch) + grad = torch.stack( + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + ) + p_stacked.grad = grad + stacked_params_dict[key] = p_stacked + tuples.append((p_stacked, state, batch_names)) + + yield tuples # <-- calling code will do the actual optimization here! + + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for i, p in enumerate(batch): # batch is list of Parameter + p.copy_(stacked_params[i]) + + +def basic_step(group, p, state, grad): + # computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet. + lr = group["lr"] + if p.numel() == p.shape[0]: + lr = lr * group["scalar_lr_scale"] + beta2 = group["betas"][1] + eps = group["eps"] + # p shape: (batch_size,) or (batch_size, 1, [1,..]) + try: + exp_avg_sq = state[ + "exp_avg_sq" + ] # shape: (batch_size,) or (batch_size, 1, [1,..]) + except KeyError: + exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["exp_avg_sq"] = exp_avg_sq + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + denom = exp_avg_sq.sqrt().add_(eps) + + return -lr * grad / denom + + +def scaling_step(group, p, state, grad): + delta = basic_step(group, p, state, grad) + if p.numel() == p.shape[0]: + return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.) + + step = state["step"] + size_update_period = group["size_update_period"] + + try: + param_rms = state["param_rms"] + scale_grads = state["scale_grads"] + scale_exp_avg_sq = state["scale_exp_avg_sq"] + except KeyError: + # we know p.ndim > 1 because we'd have returned above if not, so don't worry + # about the speial case of dim=[] that pytorch treats inconsistently. + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = param_rms.to(torch.float) + scale_exp_avg_sq = torch.zeros_like(param_rms) + scale_grads = torch.zeros( + size_update_period, *param_rms.shape, dtype=torch.float, device=p.device + ) + state["param_rms"] = param_rms + state["scale_grads"] = scale_grads + state["scale_exp_avg_sq"] = scale_exp_avg_sq + + # on every step, update the gradient w.r.t. the scale of the parameter, we + # store these as a batch and periodically update the size (for speed only, to + # avoid too many operations). + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + + # periodically recompute the value of param_rms. + if step % size_update_period == size_update_period - 1: + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()) + + param_min_rms = group["param_min_rms"] + + # scale the step size by param_rms. This is the most important "scaling" part of + # ScaledAdam + delta *= param_rms.clamp(min=param_min_rms) + + if step % size_update_period == size_update_period - 1 and step > 0: + # This block updates the size of parameter by adding a step ("delta") value in + # the direction of either shrinking or growing it. + beta2 = group["betas"][1] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + batch_size = p.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # The following may help prevent instability: don't allow the scale step to be too large in + # either direction. + scale_step.clamp_(min=-0.1, max=0.1) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + + delta.add_(p * scale_step) + + return delta + + +def momentum_step(group, p, state, grad): + delta = scaling_step(group, p, state, grad) + beta1 = group["betas"][0] + try: + stored_delta = state["delta"] + except KeyError: + stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["delta"] = stored_delta + stored_delta.mul_(beta1) + stored_delta.add_(delta, alpha=(1 - beta1)) + # we don't bother doing the "bias correction" part of Adam for beta1 because this is just + # an edge effect that affects the first 10 or so batches; and the effect of not doing it + # is just to do a slower update for the first few batches, which will help stability. + return stored_delta + + +class ScaledAdam(BatchedOptimizer): + """ + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) + + + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + Unlike common optimizers, which accept model.parameters() or groups of parameters(), + this optimizer could accept model.named_parameters() or groups of named_parameters(). + See comments of function _get_names_of_parameters for its 4 possible cases. + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period + """ + + def __init__( + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, + ): + + defaults = dict( + lr=lr, + clipping_scale=clipping_scale, + betas=betas, + scalar_lr_scale=scalar_lr_scale, + eps=eps, + param_min_rms=param_min_rms, + param_max_rms=param_max_rms, + scalar_max=scalar_max, + size_update_period=size_update_period, + clipping_update_period=clipping_update_period, + ) + + # If params only contains parameters or group of parameters, + # i.e when parameter names are not given, + # this flag will be set to False in funciton _get_names_of_parameters. + self.show_dominant_parameters = True + param_groups, parameters_names = self._get_names_of_parameters(params) + super(ScaledAdam, self).__init__(param_groups, defaults) + assert len(self.param_groups) == len(parameters_names) + self.parameters_names = parameters_names + + def _get_names_of_parameters( + self, params_or_named_params + ) -> Tuple[List[Dict], List[List[str]]]: + """ + Args: + params_or_named_params: according to the way ScaledAdam is initialized in train.py, + this argument could be one of following 4 cases, + case 1, a generator of parameter, e.g.: + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 2, a list of parameter groups with different config, e.g.: + model_param_groups = [ + {'params': model.encoder.parameters(), 'lr': 0.05}, + {'params': model.decoder.parameters(), 'lr': 0.01}, + {'params': model.joiner.parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) + + case 3, a generator of named_parameter, e.g.: + optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) + + case 4, a list of named_parameter groups with different config, e.g.: + model_named_param_groups = [ + {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, + {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, + {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, + ] + optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) + + For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. + For case 3 and case 4, firstly, names and params are extracted from input named_params, + then, these extracted params are used to initialize the underlying torch.optimizer, + and these extracted names are mainly used by function + `_show_gradient_dominating_parameter` + + Returns: + Returns a tuple containing 2 elements: + - `param_groups` with type List[Dict], each Dict element is a parameter group. + An example of `param_groups` could be: + [ + {'params': `one iterable of Parameter`, 'lr': 0.05}, + {'params': `another iterable of Parameter`, 'lr': 0.08}, + {'params': `a third iterable of Parameter`, 'lr': 0.1}, + ] + - `param_gruops_names` with type List[List[str]], + each `List[str]` is for a group['params'] in param_groups, + and each `str` is the name of a parameter. + A dummy name "foo" is related to each parameter, + if input are params without names, i.e. case 1 or case 2. + """ + # variable naming convention in this function: + # p is short for param. + # np is short for named_param. + # p_or_np is short for param_or_named_param. + # cur is short for current. + # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. + # groups is a List[group] + + iterable_or_groups = list(params_or_named_params) + if len(iterable_or_groups) == 0: + raise ValueError("optimizer got an empty parameter list") + + # The first value of returned tuple. A list of dicts containing at + # least 'params' as a key. + param_groups = [] + + # The second value of returned tuple, + # a List[List[str]], each sub-List is for a group. + param_groups_names = [] + + if not isinstance(iterable_or_groups[0], dict): + # case 1 or case 3, + # the input is an iterable of parameter or named parameter. + param_iterable_cur_group = [] + param_names_cur_group = [] + for p_or_np in iterable_or_groups: + if isinstance(p_or_np, tuple): + # case 3 + name, param = p_or_np + else: + # case 1 + assert isinstance(p_or_np, torch.Tensor) + param = p_or_np + # Assign a dummy name as a placeholder + name = "foo" + self.show_dominant_parameters = False + param_iterable_cur_group.append(param) + param_names_cur_group.append(name) + param_groups.append({"params": param_iterable_cur_group}) + param_groups_names.append(param_names_cur_group) + else: + # case 2 or case 4 + # the input is groups of parameter or named parameter. + for cur_group in iterable_or_groups: + if "named_params" in cur_group: + name_list = [x[0] for x in cur_group["named_params"]] + p_list = [x[1] for x in cur_group["named_params"]] + del cur_group["named_params"] + cur_group["params"] = p_list + else: + assert "params" in cur_group + name_list = ["foo" for _ in cur_group["params"]] + param_groups.append(cur_group) + param_groups_names.append(name_list) + + return param_groups, param_groups_names + + def __setstate__(self, state): + super(ScaledAdam, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + batch = True + + for group, group_params_names in zip(self.param_groups, self.parameters_names): + + with self.batched_params(group["params"], group_params_names) as batches: + + # batches is list of pairs (stacked_param, state). stacked_param is like + # a regular parameter, and will have a .grad, but the 1st dim corresponds to + # a stacking dim, it is not a real dim. + + if ( + len(batches[0][1]) == 0 + ): # if len(first state) == 0: not yet initialized + clipping_scale = 1 + else: + clipping_scale = self._get_clipping_scale(group, batches) + + for p, state, _ in batches: + # Perform optimization step. + # grad is not going to be None, we handled that when creating the batches. + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + grad = ( + p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale) + ) + p += momentum_step(group, p.detach(), state, grad) + + if p.numel() == p.shape[0]: # scalar parameter + scalar_max = group["scalar_max"] + p.clamp_(min=-scalar_max, max=scalar_max) + + state["step"] = cur_step + 1 + + return loss + + def _get_clipping_scale( + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] + ) -> float: + """ + Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients + by this amount before applying the rest of the update. + + Args: + group: the parameter group, an item in self.param_groups + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + assert len(tuples) >= 1 + clipping_scale = group["clipping_scale"] + (first_p, first_state, _) = tuples[0] + step = first_state["step"] + if clipping_scale is None or step == 0: + # no clipping. return early on step == 0 because the other + # parameters' state won't have been initialized yet. + return 1.0 + clipping_update_period = group["clipping_update_period"] + scalar_lr_scale = group["scalar_lr_scale"] + + tot_sumsq = torch.tensor(0.0, device=first_p.device) + for (p, state, param_names) in tuples: + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + "ScaledAdam optimizer does not support sparse gradients" + ) + if p.numel() == p.shape[0]: # a batch of scalars + tot_sumsq += (grad**2).sum() * ( + scalar_lr_scale**2 + ) # sum() to change shape [1] to [] + else: + tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + + tot_norm = tot_sumsq.sqrt() + if "model_norms" not in first_state: + first_state["model_norms"] = torch.zeros( + clipping_update_period, device=p.device + ) + first_state["model_norms"][step % clipping_update_period] = tot_norm + + irregular_estimate_steps = [ + i for i in [10, 20, 40] if i < clipping_update_period + ] + if step % clipping_update_period == 0 or step in irregular_estimate_steps: + # Print some stats. + # We don't reach here if step == 0 because we would have returned + # above. + sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + if step in irregular_estimate_steps: + sorted_norms = sorted_norms[-step:] + num_norms = sorted_norms.numel() + quartiles = [] + for n in range(0, 5): + index = min(num_norms - 1, (num_norms // 4) * n) + quartiles.append(sorted_norms[index].item()) + + median = quartiles[2] + if median - median != 0: + raise RuntimeError("Too many grads were not finite") + threshold = clipping_scale * median + if step in irregular_estimate_steps: + # use larger thresholds on first few steps of estimating threshold, + # as norm may be changing rapidly. + threshold = threshold * 2.0 + first_state["model_norm_threshold"] = threshold + percent_clipped = ( + first_state["num_clipped"] * 100.0 / num_norms + if "num_clipped" in first_state + else 0.0 + ) + first_state["num_clipped"] = 0 + quartiles = " ".join(["%.3e" % x for x in quartiles]) + logging.warning( + f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" + ) + + try: + model_norm_threshold = first_state["model_norm_threshold"] + except KeyError: + return 1.0 # threshold has not yet been set. + + ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + if ans != ans: # e.g. ans is nan + ans = 0.0 + if ans < 1.0: + first_state["num_clipped"] += 1 + if ans < 0.5: + logging.warning( + f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" + ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter( + tuples, tot_sumsq, group["scalar_lr_scale"] + ) + self._show_param_with_unusual_grad(tuples) + + if ans == 0.0: + for (p, state, param_names) in tuples: + p.grad.zero_() # get rid of infinity() + + return ans + + def _show_param_with_unusual_grad( + self, + tuples: List[Tuple[Tensor, dict, List[str]]], + ): + """ + Print information about parameter which has the largest ratio of grad-on-this-batch + divided by normal grad size. + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + largest_ratio = 0.0 + largest_name = "" + # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor) + ratios_names = [] + for (p, state, batch_param_names) in tuples: + dims = list(range(1, p.ndim)) + + def mean(x): + # workaround for bad interface of torch's "mean" for when dims is the empty list. + if len(dims) > 0: + return x.mean(dim=dims) + else: + return x + + grad_ratio = ( + (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims)) + .sqrt() + .to("cpu") + ) + + ratios_names += zip( + grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0) + ) + + ratios_names = sorted(ratios_names, reverse=True) + ratios_names = ratios_names[:10] + ratios_names = [ + (ratio, name, largest_index(tensor)) + for (ratio, name, tensor) in ratios_names + ] + + logging.debug( + f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}" + ) + + def _show_gradient_dominating_parameter( + self, + tuples: List[Tuple[Tensor, dict, List[str]]], + tot_sumsq: Tensor, + scalar_lr_scale: float, + ): + """ + Show information of parameter which dominates tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ + all_sumsq_orig = {} + for (p, state, batch_param_names) in tuples: + # p is a stacked batch parameters. + batch_grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + # Dummy values used by following `zip` statement. + batch_rms_orig = torch.full( + p.shape, scalar_lr_scale, device=batch_grad.device + ) + else: + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2 + if batch_grad.ndim > 1: + # need to guard it with if-statement because sum() sums over + # all dims if dim == (). + batch_sumsq_orig = batch_sumsq_orig.sum( + dim=list(range(1, batch_grad.ndim)) + ) + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): + + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) + + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.debug( + f"Parameter dominating tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq={(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) + + +def largest_index(x: Tensor): + x = x.contiguous() + argmax = x.abs().argmax().item() + return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)] + + +class LRScheduler(object): + """ + Base-class for learning rate schedulers where the learning-rate depends on both the + batch and the epoch. + """ + + def __init__(self, optimizer: Optimizer, verbose: bool = False): + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + self.verbose = verbose + + for group in optimizer.param_groups: + group.setdefault("base_lr", group["lr"]) + + self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + + self.epoch = 0 + self.batch = 0 + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return { + # the user might try to override the base_lr, so don't include this in the state. + # previously they were included. + # "base_lrs": self.base_lrs, + "epoch": self.epoch, + "batch": self.batch, + } + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # the things with base_lrs are a work-around for a previous problem + # where base_lrs were written with the state dict. + base_lrs = self.base_lrs + self.__dict__.update(state_dict) + self.base_lrs = base_lrs + + def get_last_lr(self) -> List[float]: + """Return last computed learning rate by current scheduler. Will be a list of float.""" + return self._last_lr + + def get_lr(self): + # Compute list of learning rates from self.epoch and self.batch and + # self.base_lrs; this must be overloaded by the user. + # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] + raise NotImplementedError + + def step_batch(self, batch: Optional[int] = None) -> None: + # Step the batch index, or just set it. If `batch` is specified, it + # must be the batch index from the start of training, i.e. summed over + # all epochs. + # You can call this in any order; if you don't provide 'batch', it should + # of course be called once per batch. + if batch is not None: + self.batch = batch + else: + self.batch = self.batch + 1 + self._set_lrs() + + def step_epoch(self, epoch: Optional[int] = None): + # Step the epoch index, or just set it. If you provide the 'epoch' arg, + # you should call this at the start of the epoch; if you don't provide the 'epoch' + # arg, you should call it at the end of the epoch. + if epoch is not None: + self.epoch = epoch + else: + self.epoch = self.epoch + 1 + self._set_lrs() + + def _set_lrs(self): + values = self.get_lr() + assert len(values) == len(self.optimizer.param_groups) + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] + + def print_lr(self, is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logging.warning( + f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" + f" of group {group} to {lr:.4e}." + ) + + +class Eden(LRScheduler): + """ + Eden scheduler. + The basic formula (before warmup) is: + lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * + (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + If you don't have the concept of epochs, or one epoch takes a very long time, + you can replace the notion of 'epoch' with some measure of the amount of data + processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to + some measure representing "quite a lot of data": say, one fifth or one third + of an entire training run, but it doesn't matter much. You could also use + Eden2 which has only the notion of batches. + + We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + lr_epochs: the number of epochs after which we start significantly + decreasing the learning rate, suggest 6 if you plan to do e.g. + 20 to 40 epochs, but may need smaller number if dataset is huge + and you will do few epochs. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + lr_epochs: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + warmup_start: float = 0.5, + verbose: bool = False, + ): + super(Eden, self).__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.lr_epochs = lr_epochs + self.warmup_batches = warmup_batches + + assert 0.0 <= warmup_start <= 1.0, warmup_start + self.warmup_start = warmup_start + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.25 * ( + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ) + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + # else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +class Eden2(LRScheduler): + """ + Eden2 scheduler, simpler than Eden because it does not use the notion of epoch, + only batches. + + The basic formula (before warmup) is: + lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup + + where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches + and then stays constant at 1. + + + E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam + + Args: + optimizer: the optimizer to change the learning rates on + lr_batches: the number of batches after which we start significantly + decreasing the learning rate, suggest 5000. + """ + + def __init__( + self, + optimizer: Optimizer, + lr_batches: Union[int, float], + warmup_batches: Union[int, float] = 500.0, + warmup_start: float = 0.5, + verbose: bool = False, + ): + super().__init__(optimizer, verbose) + self.lr_batches = lr_batches + self.warmup_batches = warmup_batches + + assert 0.0 <= warmup_start <= 1.0, warmup_start + self.warmup_start = warmup_start + + def get_lr(self): + factor = ( + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + ) ** -0.5 + warmup_factor = ( + 1.0 + if self.batch >= self.warmup_batches + else self.warmup_start + + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) + # else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ) + + return [x * factor * warmup_factor for x in self.base_lrs] + + +class FixedLRScheduler(LRScheduler): + """ + Fixed learning rate scheduler. + + Args: + optimizer: the optimizer to change the learning rates on + """ + + def __init__( + self, + optimizer: Optimizer, + verbose: bool = False, + ): + super(FixedLRScheduler, self).__init__(optimizer, verbose) + + def get_lr(self): + + return [x for x in self.base_lrs] + + +def _test_eden(): + m = torch.nn.Linear(100, 100) + optim = ScaledAdam(m.parameters(), lr=0.03) + + scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(20): + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + + logging.info(f"last lr = {scheduler.get_last_lr()}") + logging.info(f"state dict = {scheduler.state_dict()}") + + +# This is included mostly as a baseline for ScaledAdam. +class Eve(Optimizer): + """ + Implements Eve algorithm. This is a modified version of AdamW with a special + way of setting the weight-decay / shrinkage-factor, which is designed to make the + rms of the parameters approach a particular target_rms (default: 0.1). This is + for use with networks with 'scaled' versions of modules (see scaling.py), which + will be close to invariant to the absolute scale on the parameter matrix. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Eve is unpublished so far. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 3e-4; + this value means that the weight would decay significantly after + about 3k minibatches. Is not multiplied by learning rate, but + is conditional on RMS-value of parameter being > target_rms. + target_rms (float, optional): target root-mean-square value of + parameters, if they fall below this we will stop applying weight decay. + + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.98), + eps=1e-8, + weight_decay=1e-3, + target_rms=0.1, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0 <= weight_decay <= 0.1: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0 < target_rms <= 10.0: + raise ValueError("Invalid target_rms value: {}".format(target_rms)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + target_rms=target_rms, + ) + super(Eve, self).__init__(params, defaults) + + def __setstate__(self, state): + super(Eve, self).__setstate__(state) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + # Perform optimization step + grad = p.grad + if grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + beta1, beta2 = group["betas"] + + state["step"] += 1 + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + group["eps"] + ) + + step_size = group["lr"] / bias_correction1 + target_rms = group["target_rms"] + weight_decay = group["weight_decay"] + + if p.numel() > 1: + # avoid applying this weight-decay on "scaling factors" + # (which are scalar). + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + p.mul_(1 - (weight_decay * is_above_target_rms)) + + p.addcdiv_(exp_avg, denom, value=-step_size) + + if random.random() < 0.0005: + step = (exp_avg / denom) * step_size + logging.info( + f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" + ) + + return loss + + +def _test_scaled_adam(hidden_dim: int): + import timeit + + from scaling import ScaledLinear + + E = 100 + B = 4 + T = 2 + logging.info("in test_eve_cain") + # device = torch.device('cuda') + device = torch.device("cpu") + dtype = torch.float32 + + fix_random_seed(42) + # these input_magnitudes and output_magnitudes are to test that + # Abel is working as we expect and is able to adjust scales of + # different dims differently. + input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() + + for iter in [1, 0]: + fix_random_seed(42) + Linear = torch.nn.Linear if iter == 0 else ScaledLinear + + m = torch.nn.Sequential( + Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) + + train_pairs = [ + ( + 100.0 + * torch.randn(B, T, E, device=device, dtype=dtype) + * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + ) + for _ in range(20) + ] + + if iter == 0: + optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: + optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0) + scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + + start = timeit.default_timer() + avg_loss = 0.0 + for epoch in range(180): + scheduler.step_epoch() + # if epoch == 100 and iter in [2,3]: + # optim.reset_speedup() # check it doesn't crash. + + # if epoch == 130: + # opts = diagnostics.TensorDiagnosticOptions( + # 512 + # ) # allow 4 megabytes per sub-module + # diagnostic = diagnostics.attach_diagnostics(m, opts) + + for n, (x, y) in enumerate(train_pairs): + y_out = m(x) + loss = ((y_out - y) ** 2).mean() * 100.0 + if epoch == 0 and n == 0: + avg_loss = loss.item() + else: + avg_loss = 0.98 * avg_loss + 0.02 * loss.item() + if n == 0 and epoch % 5 == 0: + # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + lr = scheduler.get_last_lr()[0] + logging.info( + f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" + ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + loss.log().backward() + optim.step() + optim.zero_grad() + scheduler.step_batch() + + # diagnostic.print_diagnostics() + + stop = timeit.default_timer() + logging.info(f"Iter={iter}, Time taken: {stop - start}") + + logging.info(f"last lr = {scheduler.get_last_lr()}") + # logging.info("state dict = ", scheduler.state_dict()) + # logging.info("optim state_dict = ", optim.state_dict()) + logging.info(f"input_magnitudes = {input_magnitudes}") + logging.info(f"output_magnitudes = {output_magnitudes}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + logging.getLogger().setLevel(logging.INFO) + import subprocess + + s = subprocess.check_output( + "git status -uno .; git log -1; git diff HEAD .", shell=True + ) + logging.info(s) + import sys + + if len(sys.argv) > 1: + hidden_dim = int(sys.argv[1]) + else: + hidden_dim = 200 + + _test_scaled_adam(hidden_dim) + _test_eden() diff --git a/egs/zipvoice/zipvoice/scaling.py b/egs/zipvoice/zipvoice/scaling.py new file mode 100644 index 000000000..5211e3a76 --- /dev/null +++ b/egs/zipvoice/zipvoice/scaling.py @@ -0,0 +1,1910 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import math +import random +from typing import Optional, Tuple, Union + +import k2 +import torch +import torch.nn as nn +from torch import Tensor + +custom_bwd = lambda func: torch.amp.custom_bwd(func, device_type="cuda") +custom_fwd = lambda func: torch.amp.custom_fwd(func, device_type="cuda") + + +def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + + +# RuntimeError: Exporting the operator logaddexp to ONNX opset version +# 14 is not supported. Please feel free to request support or submit +# a pull request on PyTorch GitHub. +# +# The following function is to solve the above error when exporting +# models to ONNX via torch.jit.trace() +def logaddexp(x: Tensor, y: Tensor) -> Tensor: + # Caution(fangjun): Put torch.jit.is_scripting() before + # torch.onnx.is_in_onnx_export(); + # otherwise, it will cause errors for torch.jit.script(). + # + # torch.logaddexp() works for both torch.jit.script() and + # torch.jit.trace() but it causes errors for ONNX export. + # + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). + return torch.logaddexp(x, y) + elif torch.onnx.is_in_onnx_export(): + return logaddexp_onnx(x, y) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) + + +class PiecewiseLinear(object): + """ + Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with + the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] + respectively. + """ + + def __init__(self, *args): + assert len(args) >= 1, len(args) + if len(args) == 1 and isinstance(args[0], PiecewiseLinear): + self.pairs = list(args[0].pairs) + else: + self.pairs = [(float(x), float(y)) for x, y in args] + for x, y in self.pairs: + assert isinstance(x, (float, int)), type(x) + assert isinstance(y, (float, int)), type(y) + + for i in range(len(self.pairs) - 1): + assert self.pairs[i + 1][0] > self.pairs[i][0], ( + i, + self.pairs[i], + self.pairs[i + 1], + ) + + def __str__(self): + # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" + + def __call__(self, x): + if x <= self.pairs[0][0]: + return self.pairs[0][1] + elif x >= self.pairs[-1][0]: + return self.pairs[-1][1] + else: + cur_x, cur_y = self.pairs[0] + for i in range(1, len(self.pairs)): + next_x, next_y = self.pairs[i] + if x >= cur_x and x <= next_x: + return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) + cur_x, cur_y = next_x, next_y + assert False + + def __mul__(self, alpha): + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) + + def __add__(self, x): + if isinstance(x, (float, int)): + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) + s, x = self.get_common_basis(x) + return PiecewiseLinear( + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def max(self, x): + if isinstance(x, (float, int)): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def min(self, x): + if isinstance(x, float) or isinstance(x, int): + x = PiecewiseLinear((0, x)) + s, x = self.get_common_basis(x, include_crossings=True) + return PiecewiseLinear( + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) + + def __eq__(self, other): + return self.pairs == other.pairs + + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): + """ + Returns (self_mod, p_mod) which are equivalent piecewise linear + functions to self and p, but with the same x values. + + p: the other piecewise linear function + include_crossings: if true, include in the x values positions + where the functions indicate by this and p crosss. + """ + assert isinstance(p, PiecewiseLinear), type(p) + + # get sorted x-values without repetition. + x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + + if include_crossings: + extra_x_vals = [] + for i in range(len(x_vals) - 1): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): + # if the two lines in this subsegment potentially cross each other.. + diff_cur = abs(y_vals1[i] - y_vals2[i]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) + # `pos`, between 0 and 1, gives the relative x position, + # with 0 being x_vals[i] and 1 being x_vals[i+1]. + pos = diff_cur / (diff_cur + diff_next) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) + extra_x_vals.append(extra_x_val) + if len(extra_x_vals) > 0: + x_vals = sorted(set(x_vals + extra_x_vals)) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) + + +class ScheduledFloat(torch.nn.Module): + """ + This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); + it does not have a working forward() function. You are supposed to cast it to float, as + in, float(parent_module.whatever), and use it as something like a dropout prob. + + It is a floating point value whose value changes depending on the batch count of the + training loop. It is a piecewise linear function where you specify the (x,y) pairs + in sorted order on x; x corresponds to the batch index. For batch-index values before the + first x or after the last x, we just use the first or last y value. + + Example: + self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) + + `default` is used when self.batch_count is not set or not in training mode or in + torch.jit scripting mode. + """ + + def __init__(self, *args, default: float = 0.0): + super().__init__() + # self.batch_count and self.name will be written to in the training loop. + self.batch_count = None + self.name = None + self.default = default + self.schedule = PiecewiseLinear(*args) + + def extra_repr(self) -> str: + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) + + def __float__(self): + batch_count = self.batch_count + if ( + batch_count is None + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return float(self.default) + else: + ans = self.schedule(self.batch_count) + if random.random() < 0.0002: + logging.debug( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) + return ans + + def __add__(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule + x, default=self.default) + else: + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) + + def max(self, x): + if isinstance(x, float) or isinstance(x, int): + return ScheduledFloat(self.schedule.max(x), default=self.default) + else: + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) + + +FloatLike = Union[float, ScheduledFloat] + + +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. + """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = x_abs < min_abs + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class CutoffEstimator: + """ + Estimates cutoffs of an arbitrary numerical quantity such that a specified + proportion of items will be above the cutoff on average. + + p is the proportion of items that should be above the cutoff. + """ + + def __init__(self, p: float): + self.p = p + # total count of items + self.count = 0 + # total count of items that were above the cutoff + self.count_above = 0 + # initial cutoff value + self.cutoff = 0 + + def __call__(self, x: float) -> bool: + """ + Returns true if x is above the cutoff. + """ + ans = x > self.cutoff + self.count += 1 + if ans: + self.count_above += 1 + cur_p = self.count_above / self.count + delta_p = cur_p - self.p + if (delta_p > 0) == ans: + q = abs(delta_p) + self.cutoff = x * q + self.cutoff * (1 - q) + return ans + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.amp.autocast("cuda", enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim=dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x**2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual**2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BiasNormFunction(torch.autograd.Function): + # This computes: + # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() + # return x * scales + # (after unsqueezing the bias), but it does it in a memory-efficient way so that + # it can just store the returned value (chances are, this will also be needed for + # some other reason, related to the next operation, so we can save memory). + @staticmethod + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: + assert bias.ndim == 1 + if channel_dim < 0: + channel_dim = channel_dim + x.ndim + ctx.store_output_for_backprop = store_output_for_backprop + ctx.channel_dim = channel_dim + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tensor: + ans_or_x, scales, bias, log_scale = ctx.saved_tensors + if ctx.store_output_for_backprop: + x = ans_or_x / scales + else: + x = ans_or_x + x = x.detach() + x.requires_grad = True + bias.requires_grad = True + log_scale.requires_grad = True + with torch.enable_grad(): + # recompute scales from x, bias and log_scale. + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() + ans = x * scales + ans.backward(gradient=ans_grad) + return x.grad, bias.grad.flatten(), log_scale.grad, None, None + + +class BiasNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + Instead, we give the BiasNorm a trainable bias that it can use when + computing the scale for normalization. We also give it a (scalar) + trainable scale on the output. + + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interpreted as an offset from the input's ndim if negative. + This is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + log_scale: the initial log-scale that we multiply the output by; this + is learnable. + log_scale_min: FloatLike, minimum allowed value of log_scale + log_scale_max: FloatLike, maximum allowed value of log_scale + store_output_for_backprop: only possibly affects memory use; recommend + to set to True if you think the output of this module is more likely + than the input of this module to be required to be stored for the + backprop. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, + ) -> None: + super(BiasNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.log_scale = nn.Parameter(torch.tensor(log_scale)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + + self.log_scale_min = log_scale_min + self.log_scale_max = log_scale_max + + self.store_output_for_backprop = store_output_for_backprop + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + channel_dim = self.channel_dim + if channel_dim < 0: + channel_dim += x.ndim + bias = self.bias + for _ in range(channel_dim + 1, x.ndim): + bias = bias.unsqueeze(-1) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() + return x * scales + + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) + + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: + """ + Behaves like a constructor of a modified version of nn.Conv2d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False, but: + NO PADDING-RELATED ARGS. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv2d(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + return ans + + +class ChunkCausalDepthwiseConv1d(torch.nn.Module): + """ + Behaves like a depthwise 1d convolution, except that it is causal in + a chunkwise way, as if we had a block-triangular attention mask. + The chunk size is provided at test time (it should probably be + kept in sync with the attention mask). + + This has a little more than twice the parameters of a conventional + depthwise conv1d module: we implement it by having one + depthwise convolution, of half the width, that is causal (via + right-padding); and one depthwise convolution that is applied only + within chunks, that we multiply by a scaling factor which depends + on the position within the chunk. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): + super().__init__() + assert kernel_size % 2 == 1 + + half_kernel_size = (kernel_size + 1) // 2 + # will pad manually, on one side. + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) + + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) + + # first row is correction factors added to the scale near the left edge of the chunk, + # second row is correction factors added to the scale near the right edge of the chunk, + # both of these are added to a default scale of 1.0. + self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) + self.kernel_size = kernel_size + + with torch.no_grad(): + self.causal_conv.weight[:] *= initial_scale + self.chunkwise_conv.weight[:] *= initial_scale + if bias: + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: + """ + Forward function. Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + """ + (batch_size, num_channels, seq_len) = x.shape + + # half_kernel_size = self.kernel_size + 1 // 2 + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. + left_pad = self.kernel_size // 2 + + if chunk_size < 0 or chunk_size > seq_len: + chunk_size = seq_len + right_pad = -seq_len % chunk_size + + x = torch.nn.functional.pad(x, (left_pad, right_pad)) + + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + num_chunks = x_chunk.shape[2] // chunk_size + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size) + + x_chunk = x_chunk * chunk_scale + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] + + return x_chunk + x_causal + + def _get_chunk_scale(self, chunk_size: int): + """Returns tensor of shape (num_channels, chunk_size) that will be used to + scale the output of self.chunkwise_conv.""" + left_edge = self.chunkwise_conv_scale[0] + right_edge = self.chunkwise_conv_scale[1] + if chunk_size < self.kernel_size: + left_edge = left_edge[:, :chunk_size] + right_edge = right_edge[:, -chunk_size:] + else: + t = chunk_size - self.kernel_size + channels = left_edge.shape[0] + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) + left_edge = torch.cat((left_edge, pad), dim=-1) + right_edge = torch.cat((pad, right_edge), dim=-1) + return 1.0 + (left_edge + right_edge) + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Streaming Forward function. + + Args: + x: a Tensor of shape (batch_size, channels, seq_len) + cache: cached left context of shape (batch_size, channels, left_pad) + """ + (batch_size, num_channels, seq_len) = x.shape + + # left_pad is half_kernel_size - 1 where half_kernel_size is the size used + # in the causal conv. It's the amount by which we must pad on the left, + # to make the convolution causal. + left_pad = self.kernel_size // 2 + + # Pad cache + assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) + x = torch.cat([cache, x], dim=2) + # Update cache + cache = x[..., -left_pad:] + + x_causal = self.causal_conv(x) + assert x_causal.shape == (batch_size, num_channels, seq_len) + + x_chunk = x[..., left_pad:] + x_chunk = self.chunkwise_conv(x_chunk) # does not change shape + + chunk_scale = self._get_chunk_scale(chunk_size=seq_len) + x_chunk = x_chunk * chunk_scale + + return x_chunk + x_causal, cache + + +class BalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + ctx.save_for_backward(x) + ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors + (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config + + try: + with torch.enable_grad(): + with torch.amp.autocast("cuda", enabled=False): + x = x.to(torch.float32) + x = x.detach() + x.requires_grad = True + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) + mean = x.mean(dim=mean_dims, keepdim=True) + stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + rms = uncentered_var.clamp(min=1.0e-20).sqrt() + + m = mean / stddev + # part of loss that relates to mean / stddev + m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() + + # put a much larger scale on the RMS-max-limit loss, so that if both it and the + # m_loss are violated we fix the RMS loss first. + rms_clamped = rms.clamp(min=min_rms, max=max_rms) + r_loss = (rms_clamped / rms).log().abs() + + loss = m_loss + r_loss + + loss.backward(gradient=torch.ones_like(loss)) + loss_grad = x.grad + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) + + loss_grad = loss_grad * (grad_scale / loss_grad_rms) + + x_grad_float = x_grad.to(torch.float32) + # scale each element of loss_grad by the absolute value of the corresponding + # element of x_grad, which we view as a noisy estimate of its magnitude for that + # (frame and dimension). later we can consider factored versions. + x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) + x_grad = x_grad_mod.to(x_grad.dtype) + except Exception as e: + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) + + return x_grad, None, None, None, None, None, None + + +class Balancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, + ): + super().__init__() + + if prob is None: + prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) + self.prob = prob + # 5% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.05) + + # actually self.num_channels is no longer needed except for an assertion. + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.min_abs = min_abs + self.max_abs = max_abs + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): + return _no_op(x) + + prob = float(self.prob) + if random.random() < prob: + # The following inner-functions convert from the way we historically specified + # these limitations, as limits on the absolute value and the proportion of positive + # values, to limits on the RMS value and the (mean / stddev). + def _abs_to_rms(x): + # for normally distributed data, if the expected absolute value is x, the + # expected rms value will be sqrt(pi/2) * x. + return 1.25331413732 * x + + def _proportion_positive_to_mean(x): + def _atanh(x): + eps = 1.0e-10 + # eps is to prevent crashes if x is exactly 0 or 1. + # we'll just end up returning a fairly large value. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 + + def _approx_inverse_erf(x): + # 1 / (sqrt(pi) * ln(2)), + # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions + # this approximation is extremely crude and gets progressively worse for + # x very close to -1 or +1, but we mostly care about the "middle" region + # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, + # and math.erf(0.0407316414078772) = 0.045935330944660666, + # which is pretty close to 0.05. + return 0.8139535143 * _atanh(x) + + # first convert x from the range 0..1 to the range -1..1 which the error + # function returns + x = -1 + (2 * x) + return _approx_inverse_erf(x) + + min_mean = _proportion_positive_to_mean(float(self.min_positive)) + max_mean = _proportion_positive_to_mean(float(self.max_positive)) + min_rms = _abs_to_rms(float(self.min_abs)) + max_rms = _abs_to_rms(float(self.max_abs)) + grad_scale = float(self.grad_scale) + + assert x.shape[self.channel_dim] == self.num_channels + + return BalancerFunction.apply( + x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + + The name is for randomly printed debug info. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss, name) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. + Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: + ctx.save_for_backward(x) + ctx.module = module + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors + w = ctx.module + + try: + with torch.enable_grad(): + with torch.amp.autocast("cuda", enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, w.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.debug( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" + ) + + if metric < float(w.whitening_limit): + w.prob = w.min_prob + return x_grad, None + else: + w.prob = w.max_prob + metric.backward() + penalty_grad = x_detached.grad + scale = w.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None + except Exception as e: + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." + ) + return x_grad, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert float(whitening_limit) >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + self.grad_scale = grad_scale + + if isinstance(prob, float): + prob = (prob, prob) + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob <= self.max_prob <= 1 + self.prob = self.max_prob + self.name = None # will be set in training loop + + def forward(self, x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + grad_scale = float(self.grad_scale) + if not x.requires_grad or random.random() > self.prob or grad_scale == 0: + return _no_op(x) + else: + return WhiteningPenaltyFunction.apply(x, self) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor, name: str): + ctx.y_shape = y.shape + if random.random() < 0.002 and name is not None: + loss_sum = y.sum().item() + logging.debug(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) + + +def with_loss(x, y, name): + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y, name) + + +class ScaleGradFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, alpha: float) -> Tensor: + ctx.alpha = alpha + return x + + @staticmethod + def backward(ctx, grad: Tensor): + return grad * ctx.alpha, None + + +def scale_grad(x: Tensor, alpha: float): + return ScaleGradFunction.apply(x, alpha) + + +class ScaleGrad(nn.Module): + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + + def forward(self, x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return x + return scale_grad(x, self.alpha) + + +class LimitParamValue(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, min: float, max: float): + ctx.save_for_backward(x) + assert max >= min + ctx.min = min + ctx.max = max + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x,) = ctx.saved_tensors + # where x < ctx.min, ensure all grads are negative (this will tend to make + # x more positive). + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) + # where x > ctx.max, ensure all grads are positive (this will tend to make + # x more negative). + x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + return x_grad, None, None + + +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): + # You apply this to (typically) an nn.Parameter during training to ensure that its + # (elements mostly) stays within a supplied range. This is done by modifying the + # gradients in backprop. + # It's not necessary to do this on every batch: do it only some of the time, + # to save a little time. + if training and random.random() < prob: + return LimitParamValue.apply(x, min, max) + else: + return x + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.044 + ceil = 1.2 + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. +class Dropout2(nn.Module): + def __init__(self, p: FloatLike): + super().__init__() + self.p = p + + def forward(self, x: Tensor) -> Tensor: + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) + + +class MulForDropout3(torch.autograd.Function): + # returns (x * y * alpha) where alpha is a float and y doesn't require + # grad and is zero-or-one. + @staticmethod + @custom_fwd + def forward(ctx, x, y, alpha): + assert not y.requires_grad + ans = x * y * alpha + ctx.save_for_backward(ans) + ctx.alpha = alpha + return ans + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad): + (ans,) = ctx.saved_tensors + x_grad = ctx.alpha * ans_grad * (ans != 0) + return x_grad, None, None + + +# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, +# and it lets you choose one dimension to share the dropout mask over +class Dropout3(nn.Module): + def __init__(self, p: FloatLike, shared_dim: int): + super().__init__() + self.p = p + self.shared_dim = shared_dim + + def forward(self, x: Tensor) -> Tensor: + p = float(self.p) + if not self.training or p == 0: + return _no_op(x) + scale = 1.0 / (1 - p) + rand_shape = list(x.shape) + rand_shape[self.shared_dim] = 1 + mask = torch.rand(*rand_shape, device=x.device) > p + ans = MulForDropout3.apply(x, mask, scale) + return ans + + +class SwooshLFunction(torch.autograd.Function): + """ + swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + coeff = -0.08 + + with torch.amp.autocast("cuda", enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 + + if not requires_grad: + return y + + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = coeff + ceil = 1.0 + coeff + 0.005 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + + coeff = -0.08 + floor = coeff + ceil = 1.0 + coeff + 0.005 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshL(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation.""" + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 + if not x.requires_grad: + return k2.swoosh_l_forward(x) + else: + return k2.swoosh_l(x) + # return SwooshLFunction.apply(x) + + +class SwooshLOnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation.""" + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 + + +class SwooshRFunction(torch.autograd.Function): + """ + swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + + derivatives are between -0.08 and 0.92. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + + if x.dtype == torch.float16: + x = x.to(torch.float32) + + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + + with torch.amp.autocast("cuda", enabled=False): + with torch.enable_grad(): + x = x.detach() + x.requires_grad = True + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + + if not requires_grad: + return y + y.backward(gradient=torch.ones_like(y)) + + grad = x.grad + floor = -0.08 + ceil = 0.925 + + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) + if __name__ == "__main__": + # for self-testing only. + assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.08 + ceil = 0.925 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class SwooshR(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation.""" + if torch.jit.is_scripting() or torch.jit.is_tracing(): + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 + if not x.requires_grad: + return k2.swoosh_r_forward(x) + else: + return k2.swoosh_r(x) + # return SwooshRFunction.apply(x) + + +class SwooshROnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation.""" + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 + + +# simple version of SwooshL that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshLForward(x: Tensor): + x_offset = x - 4.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.035 + + +# simple version of SwooshR that does not redefine the backprop, used in +# ActivationDropoutAndLinearFunction. +def SwooshRForward(x: Tensor): + x_offset = x - 1.0 + log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) + return log_sum - 0.08 * x - 0.313261687 + + +class ActivationDropoutAndLinearFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): + if dropout_p != 0.0: + dropout_shape = list(x.shape) + if dropout_shared_dim is not None: + dropout_shape[dropout_shared_dim] = 1 + # else it won't be very memory efficient. + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) + else: + dropout_mask = None + + ctx.save_for_backward(x, weight, bias, dropout_mask) + + ctx.activation = activation + + forward_activation_dict = { + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, + } + # it will raise a KeyError if this fails. This will be an error. We let it + # propagate to the user. + activation_func = forward_activation_dict[activation] + x = activation_func(x) + if dropout_mask is not None: + x = x * dropout_mask + x = torch.nn.functional.linear(x, weight, bias) + return x + + @staticmethod + @custom_bwd + def backward(ctx, ans_grad: Tensor): + saved = ctx.saved_tensors + (x, weight, bias, dropout_mask) = saved + + forward_and_deriv_activation_dict = { + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, + } + # the following lines a KeyError if the activation is unrecognized. + # This will be an error. We let it propagate to the user. + func = forward_and_deriv_activation_dict[ctx.activation] + + y, func_deriv = func(x) + if dropout_mask is not None: + y = y * dropout_mask + # now compute derivative of y w.r.t. weight and bias.. + # y: (..., in_channels), ans_grad: (..., out_channels), + (out_channels, in_channels) = weight.shape + + in_channels = y.shape[-1] + g = ans_grad.reshape(-1, out_channels) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) + y_deriv = torch.matmul(ans_grad, weight) + bias_deriv = None if bias is None else g.sum(dim=0) + x_deriv = y_deriv * func_deriv + if dropout_mask is not None: + # order versus func_deriv does not matter + x_deriv = x_deriv * dropout_mask + + return x_deriv, weight_deriv, bias_deriv, None, None, None + + +class ActivationDropoutAndLinear(torch.nn.Module): + """ + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. + + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): + super().__init__() + # create a temporary module of nn.Linear that we'll steal the + # weights and bias from + l = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) + + self.weight = l.weight + # register_parameter properly handles making it a parameter when l.bias + # is None. I think there is some reason for doing it this way rather + # than just setting it to None but I don't know what it is, maybe + # something to do with exporting the module.. + self.register_parameter("bias", l.bias) + + self.activation = activation + self.dropout_p = dropout_p + self.dropout_shared_dim = dropout_shared_dim + + def forward(self, x: Tensor): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + if self.activation == "SwooshL": + x = SwooshLForward(x) + elif self.activation == "SwooshR": + x = SwooshRForward(x) + else: + assert False, self.activation + return torch.nn.functional.linear(x, self.weight, self.bias) + + return ActivationDropoutAndLinearFunction.apply( + x, + self.weight, + self.bias, + self.activation, + float(self.dropout_p), + self.dropout_shared_dim, + ) + + +def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: + if num_channels <= x.shape[-1]: + return x[..., :num_channels] + else: + shape = list(x.shape) + shape[-1] = num_channels - shape[-1] + zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) + return torch.cat((x, zeros), dim=-1) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = x.detach() + x.requires_grad = True + m = Balancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + min_abs=0.0, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_sign: x = ", x) + print("_test_balancer_sign: y grad = ", y_grad) + print("_test_balancer_sign: x grad = ", x.grad) + + +def _test_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = x.detach() + x.requires_grad = True + m = Balancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + min_abs=0.2, + max_abs=0.7, + prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_balancer_magnitude: x = ", x) + print("_test_balancer_magnitude: y grad = ", y_grad) + print("_test_balancer_magnitude: x grad = ", x.grad) + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = (1.2 - (-0.043637)) / 255.0 + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshl_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshL() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_swooshr_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = SwooshR() + + tol = 1.0 / 255.0 + torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +def _test_piecewise_linear(): + p = PiecewiseLinear((0, 10.0)) + for x in [-100, 0, 100]: + assert p(x) == 10.0 + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: + print("x, y = ", x, y) + assert p(x) == y, (x, p(x), y) + + q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] + pq = p.max(q) + for x in x_vals: + y1 = max(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p.min(q) + for x in x_vals: + y1 = min(p(x), q(x)) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + pq = p + q + for x in x_vals: + y1 = p(x) + q(x) + y2 = pq(x) + assert abs(y1 - y2) < 0.001 + + +def _test_activation_dropout_and_linear(): + in_channels = 20 + out_channels = 30 + + for bias in [True, False]: + # actually we don't test for dropout_p != 0.0 because forward functions will give + # different answers. This is because we are using the k2 implementation of + # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() + # internally, messing up the random state. + for dropout_p in [0.0]: + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) + with torch.no_grad(): + m2.weight[:] = m1[2].weight + if bias: + m2.bias[:] = m1[2].bias + # make sure forward gives same result. + x1 = torch.randn(10, in_channels) + x1.requires_grad = True + + # TEMP. + assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) + + x2 = x1.clone().detach() + x2.requires_grad = True + seed = 10 + torch.manual_seed(seed) + y1 = m1(x1) + y_grad = torch.randn_like(y1) + y1.backward(gradient=y_grad) + torch.manual_seed(seed) + y2 = m2(x2) + y2.backward(gradient=y_grad) + + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) + print("y1 = ", y1) + print("y2 = ", y2) + assert torch.allclose(y1, y2, atol=0.02) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) + if bias: + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) + print("x1.grad = ", x1.grad) + print("x2.grad = ", x2.grad) + + def isclose(a, b): + # return true if cosine similarity is > 0.9. + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + + # the SwooshL() implementation has a noisy gradient due to 1-byte + # storage of it. + assert isclose(x1.grad, x2.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.DEBUG) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_piecewise_linear() + _test_softmax() + _test_whiten() + _test_balancer_sign() + _test_balancer_magnitude() + _test_double_swish_deriv() + _test_swooshr_deriv() + _test_swooshl_deriv() + _test_activation_dropout_and_linear() diff --git a/egs/zipvoice/zipvoice/solver.py b/egs/zipvoice/zipvoice/solver.py new file mode 100644 index 000000000..a1e316ec8 --- /dev/null +++ b/egs/zipvoice/zipvoice/solver.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch + + +class DiffusionModel(torch.nn.Module): + """A wrapper of diffusion models for inference. + Args: + model: The diffusion model. + distill: Whether it is a distillation model. + """ + + def __init__( + self, + model: torch.nn.Module, + distill: bool = False, + func_name: str = "forward_fm_decoder", + ): + super().__init__() + self.model = model + self.distill = distill + self.func_name = func_name + self.model_func = getattr(self.model, func_name) + + def forward( + self, + t: torch.Tensor, + x: torch.Tensor, + text_condition: torch.Tensor, + speech_condition: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + guidance_scale: Union[float, torch.Tensor] = 0.0, + **kwargs + ) -> torch.Tensor: + """ + Forward function that Handles the classifier-free guidance. + Args: + t: The current timestep, a tensor of shape (batch, 1, 1) or a tensor of a single float. + x: The initial value, with the shape (batch, seq_len, emb_dim). + text_condition: The text_condition of the diffision model, with the shape (batch, seq_len, emb_dim). + speech_condition: The speech_condition of the diffision model, with the shape (batch, seq_len, emb_dim). + padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len). + guidance_scale: The scale of classifier-free guidance, a float or a tensor of shape (batch, 1, 1). + Retrun: + The prediction with the shape (batch, seq_len, emb_dim). + """ + if not torch.is_tensor(guidance_scale): + guidance_scale = torch.tensor( + guidance_scale, dtype=t.dtype, device=t.device + ) + if self.distill: + return self.model_func( + t=t, + xt=x, + text_condition=text_condition, + speech_condition=speech_condition, + padding_mask=padding_mask, + guidance_scale=guidance_scale, + **kwargs + ) + + if (guidance_scale == 0.0).all(): + return self.model_func( + t=t, + xt=x, + text_condition=text_condition, + speech_condition=speech_condition, + padding_mask=padding_mask, + **kwargs + ) + else: + if t.dim() != 0: + t = torch.cat([t] * 2, dim=0) + + x = torch.cat([x] * 2, dim=0) + padding_mask = torch.cat([padding_mask] * 2, dim=0) + + text_condition = torch.cat( + [torch.zeros_like(text_condition), text_condition], dim=0 + ) + + if t.dim() == 0: + if t > 0.5: + speech_condition = torch.cat( + [torch.zeros_like(speech_condition), speech_condition], dim=0 + ) + else: + guidance_scale = guidance_scale * 2 + speech_condition = torch.cat( + [speech_condition, speech_condition], dim=0 + ) + else: + assert t.dim() > 0, t + larger_t_index = (t > 0.5).squeeze(1).squeeze(1) + zero_speech_condition = torch.cat( + [torch.zeros_like(speech_condition), speech_condition], dim=0 + ) + speech_condition = torch.cat( + [speech_condition, speech_condition], dim=0 + ) + speech_condition[larger_t_index] = zero_speech_condition[larger_t_index] + guidance_scale[~larger_t_index[: larger_t_index.size(0) // 2]] *= 2 + + data_uncond, data_cond = self.model_func( + t=t, + xt=x, + text_condition=text_condition, + speech_condition=speech_condition, + padding_mask=padding_mask, + **kwargs + ).chunk(2, dim=0) + + res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond + return res + + +class EulerSolver: + def __init__( + self, + model: torch.nn.Module, + distill: bool = False, + func_name: str = "forward_fm_decoder", + ): + """Construct a Euler Solver + Args: + model: The diffusion model. + distill: Whether it is distillation model. + """ + + self.model = DiffusionModel(model, distill=distill, func_name=func_name) + + def sample( + self, + x: torch.Tensor, + text_condition: torch.Tensor, + speech_condition: torch.Tensor, + padding_mask: torch.Tensor, + num_step: int = 10, + guidance_scale: Union[float, torch.Tensor] = 0.0, + t_start: Union[float, torch.Tensor] = 0.0, + t_end: Union[float, torch.Tensor] = 1.0, + t_shift: float = 1.0, + **kwargs + ) -> torch.Tensor: + """ + Compute the sample at time `t_end` by Euler Solver. + Args: + x: The initial value at time `t_start`, with the shape (batch, seq_len, emb_dim). + text_condition: The text condition of the diffision mode, with the shape (batch, seq_len, emb_dim). + speech_condition: The speech condition of the diffision model, with the shape (batch, seq_len, emb_dim). + padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len). + num_step: The number of ODE steps. + guidance_scale: The scale for classifier-free guidance, which is + a float or a tensor with the shape (batch, 1, 1). + t_start: the start timestep in the range of [0, 1], + which is a float or a tensor with the shape (batch, 1, 1). + t_end: the end time_step in the range of [0, 1], + which is a float or a tensor with the shape (batch, 1, 1). + t_shift: shift the t toward smaller numbers so that the sampling + will emphasize low SNR region. Should be in the range of (0, 1]. + The shifting will be more significant when the number is smaller. + + Returns: + The approximated solution at time `t_end`. + """ + device = x.device + + if torch.is_tensor(t_start) and t_start.dim() > 0: + timesteps = get_time_steps_batch( + t_start=t_start, + t_end=t_end, + num_step=num_step, + t_shift=t_shift, + device=device, + ) + else: + timesteps = get_time_steps( + t_start=t_start, + t_end=t_end, + num_step=num_step, + t_shift=t_shift, + device=device, + ) + for step in range(num_step): + v = self.model( + t=timesteps[step], + x=x, + text_condition=text_condition, + speech_condition=speech_condition, + padding_mask=padding_mask, + guidance_scale=guidance_scale, + **kwargs + ) + x = x + v * (timesteps[step + 1] - timesteps[step]) + return x + + +def get_time_steps( + t_start: float = 0.0, + t_end: float = 1.0, + num_step: int = 10, + t_shift: float = 1.0, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Compute the intermediate time steps for sampling. + + Args: + t_start: The starting time of the sampling (default is 0). + t_end: The starting time of the sampling (default is 1). + num_step: The number of sampling. + t_shift: shift the t toward smaller numbers so that the sampling + will emphasize low SNR region. Should be in the range of (0, 1]. + The shifting will be more significant when the number is smaller. + device: A torch device. + Returns: + The time step with the shape (num_step + 1,). + """ + + timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device) + + timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps) + + return timesteps + + +def get_time_steps_batch( + t_start: torch.Tensor, + t_end: torch.Tensor, + num_step: int = 10, + t_shift: float = 1.0, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Compute the intermediate time steps for sampling in the batch mode. + + Args: + t_start: The starting time of the sampling (default is 0), with the shape (batch, 1, 1). + t_end: The starting time of the sampling (default is 1), with the shape (batch, 1, 1). + num_step: The number of sampling. + t_shift: shift the t toward smaller numbers so that the sampling + will emphasize low SNR region. Should be in the range of (0, 1]. + The shifting will be more significant when the number is smaller. + device: A torch device. + Returns: + The time step with the shape (num_step + 1, N, 1, 1). + """ + while t_start.dim() > 1 and t_start.size(-1) == 1: + t_start = t_start.squeeze(-1) + while t_end.dim() > 1 and t_end.size(-1) == 1: + t_end = t_end.squeeze(-1) + assert t_start.dim() == t_end.dim() == 1 + + timesteps_shape = (num_step + 1, t_start.size(0)) + timesteps = torch.zeros(timesteps_shape, device=device) + + for i in range(t_start.size(0)): + timesteps[:, i] = torch.linspace(t_start[i], t_end[i], steps=num_step + 1) + + timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps) + + return timesteps.unsqueeze(-1).unsqueeze(-1) diff --git a/egs/zipvoice/zipvoice/tokenizer.py b/egs/zipvoice/zipvoice/tokenizer.py new file mode 100644 index 000000000..87af061e6 --- /dev/null +++ b/egs/zipvoice/zipvoice/tokenizer.py @@ -0,0 +1,570 @@ +# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao +# Han Zhu, +# Wei Kang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +import unicodedata +from functools import reduce +from typing import Dict, List, Optional + +import cn2an +import inflect +import jieba +from pypinyin import Style, lazy_pinyin +from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials + +try: + from piper_phonemize import phonemize_espeak +except Exception as ex: + raise RuntimeError( + f"{ex}\nPlease run\n" + "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" + ) + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_fraction_re = re.compile(r"([0-9]+)/([0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ("etc", "et cetera"), + ("btw", "by the way"), + ] +] + + +def intersperse(sequence, item=0): + result = [item] * (len(sequence) * 2 + 1) + result[1::2] = sequence + return result + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_percent(m): + return m.group(1).replace("%", " percent ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return " " + match + " dollars " # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return " %s %s " % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return " %s %s " % (cents, cent_unit) + else: + return " zero dollars " + + +def fraction_to_words(numerator, denominator): + if numerator == 1 and denominator == 2: + return " one half " + if numerator == 1 and denominator == 4: + return " one quarter " + if denominator == 2: + return " " + _inflect.number_to_words(numerator) + " halves " + if denominator == 4: + return " " + _inflect.number_to_words(numerator) + " quarters " + return ( + " " + + _inflect.number_to_words(numerator) + + " " + + _inflect.ordinal(_inflect.number_to_words(denominator)) + + " " + ) + + +def _expand_fraction(m): + numerator = int(m.group(1)) + denominator = int(m.group(2)) + return fraction_to_words(numerator, denominator) + + +def _expand_ordinal(m): + return " " + _inflect.number_to_words(m.group(0)) + " " + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return " two thousand " + elif num > 2000 and num < 2010: + return " two thousand " + _inflect.number_to_words(num % 100) + " " + elif num % 100 == 0: + return " " + _inflect.number_to_words(num // 100) + " hundred " + else: + return ( + " " + + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace( + ", ", " " + ) + + " " + ) + else: + return " " + _inflect.number_to_words(num, andword="") + " " + + +# Normalize numbers pronunciation +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_fraction_re, _expand_fraction, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_percent_number_re, _expand_percent, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text + + +# Convert numbers to Chinese pronunciation +def number_to_chinese(text): + text = cn2an.transform(text, "an2cn") + return text + + +def map_punctuations(text): + text = text.replace(",", ",") + text = text.replace("。", ".") + text = text.replace("!", "!") + text = text.replace("?", "?") + text = text.replace(";", ";") + text = text.replace(":", ":") + text = text.replace("、", ",") + text = text.replace("‘", "'") + text = text.replace("“", '"') + text = text.replace("”", '"') + text = text.replace("’", "'") + text = text.replace("⋯", "…") + text = text.replace("···", "…") + text = text.replace("・・・", "…") + text = text.replace("...", "…") + return text + + +def is_chinese(char): + if char >= "\u4e00" and char <= "\u9fa5": + return True + else: + return False + + +def is_alphabet(char): + if (char >= "\u0041" and char <= "\u005a") or ( + char >= "\u0061" and char <= "\u007a" + ): + return True + else: + return False + + +def is_hangul(char): + letters = unicodedata.normalize("NFD", char) + return all( + ["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters] + ) + + +def is_japanese(char): + return any( + [ + start <= char <= end + for start, end in [ + ("\u3041", "\u3096"), + ("\u30a0", "\u30ff"), + ("\uff5f", "\uff9f"), + ("\u31f0", "\u31ff"), + ("\u3220", "\u3243"), + ("\u3280", "\u337f"), + ] + ] + ) + + +def get_segment(text: str) -> List[str]: + # sentence --> [ch_part, en_part, ch_part, ...] + # example : + # input : 我们是小米人,是吗? Yes I think so!霍...啦啦啦 + # output : [('我们是小米人,是吗? ', 'zh'), ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')] + segments = [] + types = [] + flag = 0 + temp_seg = "" + temp_lang = "" + + for i, ch in enumerate(text): + if is_chinese(ch): + types.append("zh") + elif is_alphabet(ch): + types.append("en") + else: + types.append("other") + + assert len(types) == len(text) + + for i in range(len(types)): + # find the first char of the seg + if flag == 0: + temp_seg += text[i] + temp_lang = types[i] + flag = 1 + else: + if temp_lang == "other": + if types[i] == temp_lang: + temp_seg += text[i] + else: + temp_seg += text[i] + temp_lang = types[i] + else: + if types[i] == temp_lang: + temp_seg += text[i] + elif types[i] == "other": + temp_seg += text[i] + else: + segments.append((temp_seg, temp_lang)) + temp_seg = text[i] + temp_lang = types[i] + flag = 1 + + segments.append((temp_seg, temp_lang)) + return segments + + +def preprocess(text: str) -> str: + text = map_punctuations(text) + return text + + +def tokenize_ZH(text: str) -> List[str]: + try: + text = number_to_chinese(text) + segs = list(jieba.cut(text)) + full = lazy_pinyin( + segs, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True + ) + phones = [] + for x in full: + # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5]. + if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")): + phones.append(x) + continue + initial = to_initials(x, strict=False) + # don't want to share tokens with espeak tokens, so use tone3 style + final = to_finals_tone3(x, strict=False, neutral_tone_with_five=True) + if initial != "": + # don't want to share tokens with espeak tokens, so add a '0' after each initial + phones.append(initial + "0") + if final != "": + phones.append(final) + return phones + except: + return [] + + +def tokenize_EN(text: str) -> List[str]: + try: + text = expand_abbreviations(text) + text = normalize_numbers(text) + tokens = phonemize_espeak(text, "en-us") + tokens = reduce(lambda x, y: x + y, tokens) + return tokens + except: + return [] + + +class TokenizerEmilia(object): + def __init__(self, token_file: Optional[str] = None, token_type="phone"): + """ + Args: + tokens: the file that contains information that maps tokens to ids, + which is a text file with '{token} {token_id}' per line. + """ + assert ( + token_type == "phone" + ), f"Only support phone tokenizer for Emilia, but get {token_type}." + self.has_tokens = False + if token_file is None: + logging.debug( + "Initialize Tokenizer without tokens file, will fail when map to ids." + ) + return + self.token2id: Dict[str, int] = {} + with open(token_file, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split("\t") + token, id = info[0], int(info[1]) + assert token not in self.token2id, token + self.token2id[token] = id + self.pad_id = self.token2id["_"] # padding + + self.vocab_size = len(self.token2id) + self.has_tokens = True + + def texts_to_token_ids( + self, + texts: List[str], + ) -> List[List[int]]: + return self.tokens_to_token_ids(self.texts_to_tokens(texts)) + + def texts_to_tokens( + self, + texts: List[str], + ) -> List[List[str]]: + """ + Args: + texts: + A list of transcripts. + Returns: + Return a list of a list of tokens [utterance][token] + """ + for i in range(len(texts)): + # Text normalization + texts[i] = preprocess(texts[i]) + + phoneme_list = [] + for text in texts: + # now only en and ch + segments = get_segment(text) + all_phoneme = [] + for index in range(len(segments)): + seg = segments[index] + if seg[1] == "zh": + phoneme = tokenize_ZH(seg[0]) + else: + if seg[1] != "en": + logging.warning( + f"The lang should be en, given {seg[1]}, skipping segment : {seg}" + ) + continue + phoneme = tokenize_EN(seg[0]) + all_phoneme += phoneme + phoneme_list.append(all_phoneme) + return phoneme_list + + def tokens_to_token_ids( + self, + tokens: List[List[str]], + intersperse_blank: bool = False, + ) -> List[List[int]]: + """ + Args: + tokens_list: + A list of token list, each corresponding to one utterance. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + + Returns: + Return a list of token id list [utterance][token_id] + """ + assert self.has_tokens, "Please initialize Tokenizer with a tokens file." + token_ids = [] + + for tks in tokens: + ids = [] + for t in tks: + if t not in self.token2id: + logging.warning(f"Skip OOV {t}") + continue + ids.append(self.token2id[t]) + + if intersperse_blank: + ids = intersperse(ids, self.pad_id) + + token_ids.append(ids) + + return token_ids + + +class TokenizerLibriTTS(object): + def __init__(self, token_file: str, token_type: str): + """ + Args: + type: the type of tokenizer, e.g., bpe, char, phone. + tokens: the file that contains information that maps tokens to ids, + which is a text file with '{token} {token_id}' per line if type is + char or phone, otherwise it is a bpe_model file. + """ + self.type = token_type + assert token_type in ["bpe", "char", "phone"] + # Parse token file + + if token_type == "bpe": + import sentencepiece as spm + + self.sp = spm.SentencePieceProcessor() + self.sp.load(token_file) + self.pad_id = self.sp.piece_to_id("") + self.vocab_size = self.sp.get_piece_size() + else: + self.token2id: Dict[str, int] = {} + with open(token_file, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split("\t") + token, id = info[0], int(info[1]) + assert token not in self.token2id, token + self.token2id[token] = id + self.pad_id = self.token2id["_"] # padding + self.vocab_size = len(self.token2id) + try: + from tacotron_cleaner.cleaners import custom_english_cleaners as cleaner + except Exception as ex: + raise RuntimeError( + f"{ex}\nPlease run\n" + "pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/" + ) + self.cleaner = cleaner + + def texts_to_token_ids( + self, + texts: List[str], + lang: str = "en-us", + ) -> List[List[int]]: + """ + Args: + texts: + A list of transcripts. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + Used when alignment is from MAS. + lang: + Language argument passed to phonemize_espeak(). + + Returns: + Return a list of token id list [utterance][token_id] + """ + for i in range(len(texts)): + # Text normalization + texts[i] = self.cleaner(texts[i]) + + if self.type == "bpe": + token_ids_list = self.sp.encode(texts) + + elif self.type == "phone": + token_ids_list = [] + for text in texts: + tokens_list = phonemize_espeak(text.lower(), lang) + tokens = [] + for t in tokens_list: + tokens.extend(t) + token_ids = [] + for t in tokens: + if t not in self.token2id: + logging.warning(f"Skip OOV {t}") + continue + token_ids.append(self.token2id[t]) + + token_ids_list.append(token_ids) + else: + token_ids_list = [] + for text in texts: + token_ids = [] + for t in text: + if t not in self.token2id: + logging.warning(f"Skip OOV {t}") + continue + token_ids.append(self.token2id[t]) + + token_ids_list.append(token_ids) + + return token_ids_list + + def tokens_to_token_ids( + self, + tokens_list: List[str], + ) -> List[List[int]]: + """ + Args: + tokens_list: + A list of token list, each corresponding to one utterance. + + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for tokens in tokens_list: + token_ids = [] + for t in tokens: + if t not in self.token2id: + logging.warning(f"Skip OOV {t}") + continue + token_ids.append(self.token2id[t]) + + token_ids_list.append(token_ids) + + return token_ids_list + + +if __name__ == "__main__": + text = "我们是5年小米人,是吗? Yes I think so! mr king, 5 years, from 2019 to 2024. 霍...啦啦啦超过90%的人咯...?!9204" + tokenizer = Tokenizer() + tokens = tokenizer.texts_to_tokens([text]) + print(f"tokens : {tokens}") + tokens2 = "|".join(tokens[0]) + print(f"tokens2 : {tokens2}") + tokens2 = tokens2.split("|") + assert tokens[0] == tokens2 diff --git a/egs/zipvoice/zipvoice/train_distill.py b/egs/zipvoice/zipvoice/train_distill.py new file mode 100644 index 000000000..ae784050b --- /dev/null +++ b/egs/zipvoice/zipvoice/train_distill.py @@ -0,0 +1,1043 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script trains a ZipVoice-Distill model starting from a ZipVoice model. +It has two distillation stages. + +Usage: + +(1) The first distillation stage with a fixed ZipVoice model as the teacher. +python3 zipvoice/train_distill.py \ + --world-size 8 \ + --use-fp16 1 \ + --tensorboard 1 \ + --dataset "emilia" \ + --base-lr 0.0005 \ + --max-duration 500 \ + --token-file "data/tokens_emilia.txt" \ + --manifest-dir "data/fbank_emilia" \ + --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \ + --num-updates 60000 \ + --distill-stage "first" \ + --exp-dir zipvoice/exp_zipvoice_distill_1stage + +(2) The second distillation stage with a EMA model as the teacher. +python3 zipvoice/train_distill.py \ + --world-size 8 \ + --use-fp16 1 \ + --tensorboard 1 \ + --dataset "emilia" \ + --base-lr 0.0001 \ + --max-duration 500 \ + --token-file "data/tokens_emilia.txt" \ + --manifest-dir "data/fbank_emilia" \ + --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \ + --num-updates 2000 \ + --distill-stage "second" \ + --exp-dir zipvoice/exp_zipvoice_distill +""" + +import argparse +import copy +import logging +import os +import random +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from checkpoint import load_checkpoint, save_checkpoint +from lhotse.cut import Cut, CutSet +from lhotse.utils import fix_random_seed +from model import get_distill_model, get_model +from optim import FixedLRScheduler, ScaledAdam +from tokenizer import TokenizerEmilia, TokenizerLibriTTS +from torch import Tensor +from torch.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from train_flow import add_model_arguments, get_params +from tts_datamodule import TtsDataModule +from utils import ( + condition_time_mask, + get_adjusted_batch_count, + prepare_input, + set_batch_count, +) + +from icefall import diagnostics +from icefall.checkpoint import ( + remove_checkpoints, + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + make_pad_mask, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--num-updates", + type=int, + default=0, + help="Number of updates to train, will ignore num_epochs if > 0.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--teacher-model", + type=str, + help="""Checkpoints of pre-trained teacher model""", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipvoice/exp_zipvoice_distill", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.001, help="The base learning rate." + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=50, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--feat-scale", + type=float, + default=0.1, + help="The scale factor of fbank feature", + ) + + parser.add_argument( + "--ema-decay", + type=float, + default=0.9999, + help="The EMA decay factor of target model in distillation.", + ) + parser.add_argument( + "--distill-stage", + type=str, + choices=["first", "second"], + help="The stage of distillation.", + ) + + parser.add_argument( + "--dataset", + type=str, + default="emilia", + choices=["emilia", "libritts"], + help="The used training dataset", + ) + + add_model_arguments(parser) + + return parser + + +def ema(new_model, ema_model, decay): + if isinstance(new_model, DDP): + new_model = new_model.module + if isinstance(ema_model, DDP): + ema_model = ema_model.module + new_model_dict = new_model.state_dict() + ema_model_dict = ema_model.state_dict() + for key in new_model_dict.keys(): + ema_model_dict[key].data.copy_( + ema_model_dict[key].data * decay + new_model_dict[key].data * (1 - decay) + ) + + +def resume_checkpoint( + params: AttributeDict, model: nn.Module, model_avg: nn.Module, model_ema: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, model=model, model_avg=model_avg, model_ema=model_ema, strict=True + ) + + if params.start_epoch > 1: + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def compute_fbank_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + teacher_model: Union[nn.Module, DDP], + features: Tensor, + features_lens: Tensor, + tokens: List[List[int]], + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. + teacher_model: + The teacher model for distillation. + features: + The target acoustic feature. + features_lens: + The number of frames of each utterance. + tokens: + Input tokens that representing the transcripts. + durations: + Duration of each token. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + batch_size, num_frames, _ = features.shape + + features = torch.nn.functional.pad( + features, (0, 0, 0, num_frames - features.size(1)) + ) # (B, T, F) + noise = torch.randn_like(features) # (B, T, F) + + # Sampling t and guidance_scale from uniform distribution + + t_value = random.random() + t = torch.ones(batch_size, 1, 1, device=device) * t_value + if params.distill_stage == "first": + guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 + else: + guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 + 1 + xt = features * t + noise * (1 - t) + t_delta_fix = random.uniform(0.0, min(0.3, 1 - t_value)) + t_delta_ema = random.uniform(0.0, min(0.3, 1 - t_value - t_delta_fix)) + t_dest = t + t_delta_fix + t_delta_ema + + with torch.no_grad(): + speech_condition_mask = condition_time_mask( + features_lens=features_lens, + mask_percent=(0.7, 1.0), + max_len=features.size(1), + ) + + if params.distill_stage == "first": + teacher_x_t_mid, _ = teacher_model.sample_intermediate( + tokens=tokens, + features=features, + features_lens=features_lens, + noise=xt, + speech_condition_mask=speech_condition_mask, + t_start=t, + t_end=t + t_delta_fix, + num_step=1, + guidance_scale=guidance_scale, + ) + + target_x1, _ = teacher_model.sample_intermediate( + tokens=tokens, + features=features, + features_lens=features_lens, + noise=teacher_x_t_mid, + speech_condition_mask=speech_condition_mask, + t_start=t + t_delta_fix, + t_end=t_dest, + num_step=1, + guidance_scale=guidance_scale, + ) + else: + teacher_x_t_mid, _ = teacher_model( + tokens=tokens, + features=features, + features_lens=features_lens, + noise=xt, + speech_condition_mask=speech_condition_mask, + t_start=t, + t_end=t + t_delta_fix, + num_step=1, + guidance_scale=guidance_scale, + ) + + target_x1, _ = teacher_model( + tokens=tokens, + features=features, + features_lens=features_lens, + noise=teacher_x_t_mid, + speech_condition_mask=speech_condition_mask, + t_start=t + t_delta_fix, + t_end=t_dest, + num_step=1, + guidance_scale=guidance_scale, + ) + + with torch.set_grad_enabled(is_training): + + pred_x1, _ = model( + tokens=tokens, + features=features, + features_lens=features_lens, + noise=xt, + speech_condition_mask=speech_condition_mask, + t_start=t, + t_end=t_dest, + num_step=1, + guidance_scale=guidance_scale, + ) + pred_v = (pred_x1 - xt) / (t_dest - t) + + padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T) + loss_mask = speech_condition_mask & (~padding_mask) + + target_v = (target_x1 - xt) / (t_dest - t) + loss = torch.mean((pred_v[loss_mask] - target_v[loss_mask]) ** 2) + + ut = features - noise # (B, T, F) + + ref_loss = torch.mean((pred_v[loss_mask] - ut[loss_mask]) ** 2) + + assert loss.requires_grad == is_training + info = MetricsTracker() + num_frames = features_lens.sum().item() + info["frames"] = num_frames + info["loss"] = loss.detach().cpu().item() * num_frames + info["ref_loss"] = ref_loss.detach().cpu().item() * num_frames + return loss, info + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + teacher_model: Union[nn.Module, DDP], + tokenizer: TokenizerEmilia, + optimizer: Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + teacher_model: + The model for distillation. + tokenizer: + Used to convert text to tokens. + optimizer: + The optimizer. + scheduler: + The learning rate scheduler, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + model_ema=teacher_model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params) + 100000) + + if ( + params.valid_interval is None + and batch_idx == 0 + and not params.print_diagnostics + ) or ( + params.valid_interval is not None + and params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + teacher_model=teacher_model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + params.batch_idx_train += 1 + + batch_size = len(batch["text"]) + + tokens, features, features_lens = prepare_input( + params=params, + batch=batch, + device=device, + tokenizer=tokenizer, + return_tokens=True, + return_feature=True, + ) + + try: + with autocast("cuda", enabled=params.use_fp16): + loss, loss_info = compute_fbank_loss( + params=params, + model=model, + teacher_model=teacher_model, + features=features, + features_lens=features_lens, + tokens=tokens, + is_training=True, + ) + + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + scaler.scale(loss).backward() + + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + if params.distill_stage == "second": + ema(model, teacher_model, params.ema_decay) + except RuntimeError as e: + if "out of memory" in str(e): + logging.info(f"out of memory error at rank {rank}") + # optimizer.zero_grad() + # duration_optimizer.zero_grad() + torch.cuda.empty_cache() + raise + continue + else: + logging.info(f"Caught exception : {e}.") + save_bad_model() + raise + except Exception as e: + logging.info(f"Caught exception : {e}.") + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + if ( + params.batch_idx_train > 0 + and params.num_updates > 0 + and params.batch_idx_train > params.num_updates + ): + break + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 1024.0 or ( + cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + loss_value = tot_loss["loss"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + teacher_model: Optional[nn.Module], + tokenizer: TokenizerEmilia, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + tokens, features, features_lens = prepare_input( + params=params, + batch=batch, + device=device, + tokenizer=tokenizer, + return_tokens=True, + return_feature=True, + ) + + loss, loss_info = compute_fbank_loss( + params=params, + model=model, + teacher_model=teacher_model, + features=features, + features_lens=features_lens, + tokens=tokens, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + os.makedirs(f"{params.exp_dir}/fbank", exist_ok=True) + + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + if params.dataset == "emilia": + tokenizer = TokenizerEmilia( + token_file=params.token_file, token_type=params.token_type + ) + elif params.dataset == "libritts": + tokenizer = TokenizerLibriTTS( + token_file=params.token_file, token_type=params.token_type + ) + + params.vocab_size = tokenizer.vocab_size + params.pad_id = tokenizer.pad_id + + params.device = device + + logging.info(params) + + logging.info("About to create model") + + assert params.teacher_model is not None + logging.info(f"Loading pre-trained model from {params.teacher_model}") + model = get_distill_model(params) + _ = load_checkpoint( + filename=params.teacher_model, + model=model, + strict=(params.distill_stage == "second"), + ) + + if params.distill_stage == "first": + teacher_model = get_model(params) + _ = load_checkpoint( + filename=params.teacher_model, model=teacher_model, strict=True + ) + else: + teacher_model = copy.deepcopy(model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of parameters : {num_param}") + + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + assert params.start_epoch > 0, params.start_epoch + if params.start_epoch > 1: + logging.info(f"Resuming from epoch {params.start_epoch}") + if params.distill_stage == "first": + checkpoints = resume_checkpoint( + params=params, model=model, model_avg=model_avg + ) + else: + checkpoints = resume_checkpoint( + params=params, model=model, model_avg=model_avg, model_ema=teacher_model + ) + + model = model.to(device) + teacher_model.to(device) + teacher_model.eval() + + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + # only update the fm_decoder + num_trainable = 0 + for name, p in model.named_parameters(): + if "fm_decoder" in name: + p.requires_grad = True + num_trainable += p.numel() + else: + p.requires_grad = False + + logging.info( + "A total of {} trainable parameters ({:.3f}% of the whole model)".format( + num_trainable, num_trainable / num_param * 100 + ) + ) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs( + model, + lr=params.base_lr, + include_names=True, + ), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = FixedLRScheduler(optimizer) + + scaler = GradScaler("cuda", enabled=params.use_fp16) + + if params.start_epoch > 1 and checkpoints is not None: + # load state_dict for optimizers + if "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + # load state_dict for schedulers + if "scheduler" in checkpoints: + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt_emilia(c: Cut): + if c.duration < 1.0 or c.duration > 30.0: + return False + return True + + def remove_short_and_long_utt_libritts(c: Cut): + if c.duration < 1.0 or c.duration > 20.0: + return False + return True + + datamodule = TtsDataModule(args) + if params.dataset == "emilia": + train_cuts = CutSet.mux( + datamodule.train_emilia_EN_cuts(), + datamodule.train_emilia_ZH_cuts(), + weights=[46000, 49000], + ) + train_cuts = train_cuts.filter(remove_short_and_long_utt_emilia) + dev_cuts = CutSet.mux( + datamodule.dev_emilia_EN_cuts(), + datamodule.dev_emilia_ZH_cuts(), + weights=[0.5, 0.5], + ) + elif params.dataset == "libritts": + train_cuts = datamodule.train_libritts_cuts() + train_cuts = train_cuts.filter(remove_short_and_long_utt_libritts) + dev_cuts = datamodule.dev_libritts_cuts() + + train_dl = datamodule.train_dataloaders(train_cuts) + + valid_dl = datamodule.dev_dataloaders(dev_cuts) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + teacher_model=teacher_model, + tokenizer=tokenizer, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + model_avg=model_avg, + model_ema=teacher_model, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + TtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/zipvoice/zipvoice/train_flow.py b/egs/zipvoice/zipvoice/train_flow.py new file mode 100644 index 000000000..74d81b726 --- /dev/null +++ b/egs/zipvoice/zipvoice/train_flow.py @@ -0,0 +1,1108 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang, +# Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script trains a ZipVoice model with the flow-matching loss. + +Usage: + +python3 zipvoice/train_flow.py \ + --world-size 8 \ + --use-fp16 1 \ + --dataset emilia \ + --max-duration 500 \ + --lr-hours 30000 \ + --lr-batches 7500 \ + --token-file "data/tokens_emilia.txt" \ + --manifest-dir "data/fbank_emilia" \ + --num-epochs 11 \ + --exp-dir zipvoice/exp_zipvoice +""" + +import argparse +import copy +import logging +import os +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from checkpoint import load_checkpoint, save_checkpoint +from lhotse.cut import Cut, CutSet +from lhotse.utils import fix_random_seed +from model import get_model +from optim import Eden, ScaledAdam +from tokenizer import TokenizerEmilia, TokenizerLibriTTS +from torch import Tensor +from torch.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import TtsDataModule +from utils import get_adjusted_batch_count, prepare_input, set_batch_count + +from icefall import diagnostics +from icefall.checkpoint import ( + remove_checkpoints, + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--fm-decoder-downsampling-factor", + type=str, + default="1,2,4,2,1", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--fm-decoder-num-layers", + type=str, + default="2,2,4,4,4", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--fm-decoder-cnn-module-kernel", + type=str, + default="31,15,7,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--fm-decoder-feedforward-dim", + type=int, + default=1536, + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--fm-decoder-num-heads", + type=int, + default=4, + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--fm-decoder-dim", + type=int, + default=512, + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--text-encoder-downsampling-factor", + type=str, + default="1", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--text-encoder-num-layers", + type=str, + default="4", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--text-encoder-feedforward-dim", + type=int, + default=512, + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--text-encoder-cnn-module-kernel", + type=str, + default="9", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--text-encoder-num-heads", + type=int, + default=4, + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--text-encoder-dim", + type=int, + default=192, + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=int, + default=32, + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=int, + default=12, + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=int, + default=4, + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default=48, + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--time-embed-dim", + type=int, + default=192, + help="Embedding dimension of timestamps embedding.", + ) + + parser.add_argument( + "--text-embed-dim", + type=int, + default=192, + help="Embedding dimension of text embedding.", + ) + + parser.add_argument( + "--token-type", + type=str, + default="phone", + choices=["phone", "char", "bpe"], + help="Input token type of TTS model, by default, " + "we use phone for emilia, char for libritts.", + ) + + parser.add_argument( + "--token-file", + type=str, + default="data/tokens_emilia.txt", + help="The file that contains information that maps tokens to ids," + "which is a text file with '{token}\t{token_id}' per line if type is" + "char or phone, otherwise it is a bpe_model file.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=11, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="""Checkpoints of pre-trained models, will load it if not None + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipvoice/exp_zipvoice", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.02, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--lr-hours", + type=float, + default=0, + help="""If positive, --epoch is ignored and it specifies the number of hours + that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=50, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--feat-scale", + type=float, + default=0.1, + help="The scale factor of fbank feature", + ) + + parser.add_argument( + "--condition-drop-ratio", + type=float, + default=0.2, + help="The drop rate of text condition during training.", + ) + + parser.add_argument( + "--dataset", + type=str, + default="emilia", + choices=["emilia", "libritts"], + help="The used training dataset", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - sampling_rate: Sampling rate of the wavform. + + - frame_shift_ms: Frame shift in milliseconds. + + - feat_dim: The model input dim. It has to match the one used + in computing features. + + - env_info: A dict containing information about the environment. + + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 4000, + "sampling_rate": 24000, + "frame_shift_ms": 256 / 24000 * 1000, + "feat_dim": 100, + "env_info": get_env_info(), + } + ) + + return params + + +def resume_checkpoint( + params: AttributeDict, model: nn.Module, model_avg: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + logging.info(f"Resuming from file {filename}") + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, model=model, model_avg=model_avg, strict=True + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def compute_fbank_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + features: Tensor, + features_lens: Tensor, + tokens: List[List[int]], + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. + features: + The target acoustic feature. + features_lens: + The number of frames of each utterance. + tokens: + Input tokens that representing the transcripts. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + batch_size, num_frames, _ = features.shape + + features = torch.nn.functional.pad( + features, (0, 0, 0, num_frames - features.size(1)) + ) # (B, T, F) + noise = torch.randn_like(features) # (B, T, F) + + # Sampling t from uniform distribution + if is_training: + t = torch.rand(batch_size, 1, 1, device=device) + else: + t = ( + (torch.arange(batch_size, device=device) / batch_size) + .unsqueeze(1) + .unsqueeze(2) + ) + with torch.set_grad_enabled(is_training): + + loss = model( + tokens=tokens, + features=features, + features_lens=features_lens, + noise=noise, + t=t, + condition_drop_ratio=params.condition_drop_ratio, + ) + + assert loss.requires_grad == is_training + info = MetricsTracker() + num_frames = features_lens.sum().item() + info["frames"] = num_frames + info["loss"] = loss.detach().cpu().item() * num_frames + + return loss, info + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TokenizerEmilia, + optimizer: Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to tokens. + optimizer: + The optimizer. + scheduler: + The learning rate scheduler, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + if ( + params.valid_interval is None + and batch_idx == 0 + and not params.print_diagnostics + ) or ( + params.valid_interval is not None + and params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + params.batch_idx_train += 1 + + batch_size = len(batch["text"]) + + tokens, features, features_lens = prepare_input( + params=params, + batch=batch, + device=device, + tokenizer=tokenizer, + return_tokens=True, + return_feature=True, + ) + + try: + with autocast("cuda", enabled=params.use_fp16): + loss, loss_info = compute_fbank_loss( + params=params, + model=model, + features=features, + features_lens=features_lens, + tokens=tokens, + is_training=True, + ) + + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + scaler.scale(loss).backward() + + scheduler.step_batch(params.batch_idx_train) + # Use the number of hours of speech to adjust the learning rate + if params.lr_hours > 0: + scheduler.step_epoch( + params.batch_idx_train + * params.max_duration + * params.world_size + / 3600 + ) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except RuntimeError as e: + if "out of memory" in str(e): + logging.info(f"out of memory error at rank {rank}") + # optimizer.zero_grad() + # duration_optimizer.zero_grad() + torch.cuda.empty_cache() + raise + continue + else: + logging.info(f"Caught exception : {e}.") + save_bad_model() + raise + except Exception as e: + logging.info(f"Caught exception : {e}.") + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 1024.0 or ( + cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + loss_value = tot_loss["loss"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TokenizerEmilia, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + tokens, features, features_lens = prepare_input( + params=params, + batch=batch, + device=device, + tokenizer=tokenizer, + return_tokens=True, + return_feature=True, + ) + + loss, loss_info = compute_fbank_loss( + params=params, + model=model, + features=features, + features_lens=features_lens, + tokens=tokens, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + os.makedirs(f"{params.exp_dir}/fbank", exist_ok=True) + + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + if params.dataset == "emilia": + tokenizer = TokenizerEmilia( + token_file=params.token_file, token_type=params.token_type + ) + elif params.dataset == "libritts": + tokenizer = TokenizerLibriTTS( + token_file=params.token_file, token_type=params.token_type + ) + params.vocab_size = tokenizer.vocab_size + params.pad_id = tokenizer.pad_id + + params.device = device + + logging.info(params) + + logging.info("About to create model") + + model = get_model(params) + if params.checkpoint is not None: + logging.info(f"Loading pre-trained model from {params.checkpoint}") + _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True) + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of parameters : {num_param}") + + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + assert params.start_epoch > 0, params.start_epoch + if params.start_epoch > 1: + checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg) + + model = model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs( + model, + lr=params.base_lr, + include_names=True, + ), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + assert params.lr_hours >= 0 + if params.lr_hours > 0: + scheduler = Eden(optimizer, params.lr_batches, params.lr_hours) + else: + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + scaler = GradScaler("cuda", enabled=params.use_fp16) + + if params.start_epoch > 1 and checkpoints is not None: + # load state_dict for optimizers + if "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + # load state_dict for schedulers + if "scheduler" in checkpoints: + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt_emilia(c: Cut): + if c.duration < 1.0 or c.duration > 30.0: + return False + return True + + def remove_short_and_long_utt_libritts(c: Cut): + if c.duration < 1.0 or c.duration > 20.0: + return False + return True + + datamodule = TtsDataModule(args) + if params.dataset == "emilia": + train_cuts = CutSet.mux( + datamodule.train_emilia_EN_cuts(), + datamodule.train_emilia_ZH_cuts(), + weights=[46000, 49000], + ) + train_cuts = train_cuts.filter(remove_short_and_long_utt_emilia) + dev_cuts = CutSet.mux( + datamodule.dev_emilia_EN_cuts(), + datamodule.dev_emilia_ZH_cuts(), + weights=[0.5, 0.5], + ) + elif params.dataset == "libritts": + train_cuts = datamodule.train_libritts_cuts() + train_cuts = train_cuts.filter(remove_short_and_long_utt_libritts) + dev_cuts = datamodule.dev_libritts_cuts() + + train_dl = datamodule.train_dataloaders(train_cuts) + + valid_dl = datamodule.dev_dataloaders(dev_cuts) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + if params.lr_hours == 0: + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + tokenizer=tokenizer, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + TtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/zipvoice/zipvoice/tts_datamodule.py b/egs/zipvoice/zipvoice/tts_datamodule.py new file mode 100644 index 000000000..e8ea7a4eb --- /dev/null +++ b/egs/zipvoice/zipvoice/tts_datamodule.py @@ -0,0 +1,456 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin, +# Han Zhu, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import torch +from feature import TorchAudioFbank, TorchAudioFbankConfig +from lhotse import CutSet, load_manifest_lazy, validate +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, +) +from lhotse.dataset.collation import collate_audio +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + BatchIO, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed, ifnone +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +SAMPLING_RATE = 24000 + + +class TtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank_emilia"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=100, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=8, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = SAMPLING_RATE + config = TorchAudioFbankConfig( + sampling_rate=sampling_rate, + n_mels=100, + n_fft=1024, + hop_length=256, + ) + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = SAMPLING_RATE + config = TorchAudioFbankConfig( + sampling_rate=sampling_rate, + n_mels=100, + n_fft=1024, + hop_length=256, + ) + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + dev_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + dev_dl = DataLoader( + validate, + sampler=dev_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return dev_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = SAMPLING_RATE + config = TorchAudioFbankConfig( + sampling_rate=sampling_rate, + n_mels=100, + n_fft=1024, + hop_length=256, + ) + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)), + return_cuts=self.args.return_cuts, + return_audio=True, + ) + else: + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + return_audio=True, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_emilia_EN_cuts(self) -> CutSet: + logging.info("About to get train the EN subset") + return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz") + + @lru_cache() + def train_emilia_ZH_cuts(self) -> CutSet: + logging.info("About to get train the ZH subset") + return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz") + + @lru_cache() + def dev_emilia_EN_cuts(self) -> CutSet: + logging.info("About to get dev the EN subset") + return load_manifest_lazy( + self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz" + ) + + @lru_cache() + def dev_emilia_ZH_cuts(self) -> CutSet: + logging.info("About to get dev the ZH subset") + return load_manifest_lazy( + self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz" + ) + + @lru_cache() + def train_libritts_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_libritts_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz" + ) + + +class SpeechSynthesisDataset(torch.utils.data.Dataset): + """ + The PyTorch Dataset for the speech synthesis task. + Each item in this dataset is a dict of: + + .. code-block:: + + { + 'audio': (B x NumSamples) float tensor + 'features': (B x NumFrames x NumFeatures) float tensor + 'audio_lens': (B, ) int tensor + 'features_lens': (B, ) int tensor + 'text': List[str] of len B # when return_text=True + 'tokens': List[List[str]] # when return_tokens=True + 'speakers': List[str] of len B # when return_spk_ids=True + 'cut': List of Cuts # when return_cuts=True + } + """ + + def __init__( + self, + cut_transforms: List[Callable[[CutSet], CutSet]] = None, + feature_input_strategy: BatchIO = PrecomputedFeatures(), + feature_transforms: Union[Sequence[Callable], Callable] = None, + return_text: bool = True, + return_tokens: bool = False, + return_spk_ids: bool = False, + return_cuts: bool = False, + return_audio: bool = False, + ) -> None: + super().__init__() + + self.cut_transforms = ifnone(cut_transforms, []) + self.feature_input_strategy = feature_input_strategy + + self.return_text = return_text + self.return_tokens = return_tokens + self.return_spk_ids = return_spk_ids + self.return_cuts = return_cuts + self.return_audio = return_audio + + if feature_transforms is None: + feature_transforms = [] + elif not isinstance(feature_transforms, Sequence): + feature_transforms = [feature_transforms] + + assert all( + isinstance(transform, Callable) for transform in feature_transforms + ), "Feature transforms must be Callable" + self.feature_transforms = feature_transforms + + def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: + validate_for_tts(cuts) + + for transform in self.cut_transforms: + cuts = transform(cuts) + + features, features_lens = self.feature_input_strategy(cuts) + + for transform in self.feature_transforms: + features = transform(features) + + batch = { + "features": features, + "features_lens": features_lens, + } + + if self.return_audio: + audio, audio_lens = collate_audio(cuts) + batch["audio"] = audio + batch["audio_lens"] = audio_lens + + if self.return_text: + # use normalized text + text = [cut.supervisions[0].normalized_text for cut in cuts] + batch["text"] = text + + if self.return_tokens: + tokens = [cut.tokens for cut in cuts] + batch["tokens"] = tokens + + if self.return_spk_ids: + batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts] + + if self.return_cuts: + batch["cut"] = [cut for cut in cuts] + + return batch + + +def validate_for_tts(cuts: CutSet) -> None: + validate(cuts) + for cut in cuts: + assert ( + len(cut.supervisions) == 1 + ), "Only the Cuts with single supervision are supported." diff --git a/egs/zipvoice/zipvoice/utils.py b/egs/zipvoice/zipvoice/utils.py new file mode 100644 index 000000000..4092d0ae4 --- /dev/null +++ b/egs/zipvoice/zipvoice/utils.py @@ -0,0 +1,219 @@ +from typing import Any, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn.parallel import DistributedDataParallel as DDP + + +class AttributeDict(dict): + def __getattr__(self, key): + if key in self: + return self[key] + raise AttributeError(f"No such attribute '{key}'") + + def __setattr__(self, key, value): + self[key] = value + + def __delattr__(self, key): + if key in self: + del self[key] + return + raise AttributeError(f"No such attribute '{key}'") + + +def prepare_input( + params: AttributeDict, + batch: dict, + device: torch.device, + tokenizer: Optional[Any] = None, + return_tokens: bool = False, + return_feature: bool = False, + return_audio: bool = False, + return_prompt: bool = False, +): + """ + Parse the features and targets of the current batch. + Args: + params: + It is returned by :func:`get_params`. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + sp: + Used to convert text to bpe tokens. + device: + The device of Tensor. + """ + return_list = [] + + if return_tokens: + assert tokenizer is not None + + if params.token_type == "phone": + tokens = tokenizer.tokens_to_token_ids(batch["tokens"]) + else: + tokens = tokenizer.texts_to_token_ids(batch["text"]) + return_list += [tokens] + + if return_feature: + features = batch["features"].to(device) + features_lens = batch["features_lens"].to(device) + return_list += [features * params.feat_scale, features_lens] + + if return_audio: + return_list += [batch["audio"], batch["audio_lens"]] + + if return_prompt: + if return_tokens: + if params.token_type == "phone": + prompt_tokens = tokenizer.tokens_to_token_ids(batch["prompt"]["tokens"]) + else: + prompt_tokens = tokenizer.texts_to_token_ids(batch["prompt"]["text"]) + return_list += [prompt_tokens] + if return_feature: + prompt_features = batch["prompt"]["features"].to(device) + prompt_features_lens = batch["prompt"]["features_lens"].to(device) + return_list += [prompt_features * params.feat_scale, prompt_features_lens] + if return_audio: + return_list += [batch["prompt"]["audio"], batch["prompt"]["audio_lens"]] + + return return_list + + +def prepare_avg_tokens_durations(features_lens, tokens_lens): + tokens_durations = [] + for i in range(len(features_lens)): + utt_duration = features_lens[i] + avg_token_duration = utt_duration // tokens_lens[i] + tokens_durations.append([avg_token_duration] * tokens_lens[i]) + return tokens_durations + + +def pad_labels(y: List[List[int]], pad_id: int, device: torch.device): + """ + Pad the transcripts to the same length with zeros. + + Args: + y: the transcripts, which is a list of a list + + Returns: + Return a Tensor of padded transcripts. + """ + y = [l + [pad_id] for l in y] + length = max([len(l) for l in y]) + y = [l + [pad_id] * (length - len(l)) for l in y] + return torch.tensor(y, dtype=torch.int64, device=device) + + +def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor: + """ + Gets position in the transcript for each frame, i.e. the position + in the symbol-sequence to look up. + + Args: + durations: + Duration of each token in transcripts. + num_frames: + The maximum frame length of the current batch. + + Returns: + Return a Tensor of shape (batch_size, num_frames) + """ + durations = [x + [num_frames - sum(x)] for x in durations] + batch_size = len(durations) + ans = torch.zeros(batch_size, num_frames, dtype=torch.int64) + for b in range(batch_size): + this_dur = durations[b] + cur_frame = 0 + for i, d in enumerate(this_dur): + ans[b, cur_frame : cur_frame + d] = i + cur_frame += d + assert cur_frame == num_frames, (cur_frame, num_frames) + return ans + + +def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def condition_time_mask( + features_lens: torch.Tensor, mask_percent: Tuple[float, float], max_len: int = 0 +) -> torch.Tensor: + """ + Apply Time masking. + Args: + features_lens: + input tensor of shape ``(B)`` + mask_size: + the width size for masking. + max_len: + the maximum length of the mask. + Returns: + Return a 2-D bool tensor (B, T), where masked positions + are filled with `True` and non-masked positions are + filled with `False`. + """ + mask_size = ( + torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent) + * features_lens + ).to(torch.int64) + mask_starts = ( + torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size) + ).to(torch.int64) + mask_ends = mask_starts + mask_size + max_len = max(max_len, features_lens.max()) + seq_range = torch.arange(0, max_len, device=features_lens.device) + mask = (seq_range[None, :] >= mask_starts[:, None]) & ( + seq_range[None, :] < mask_ends[:, None] + ) + return mask + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + max_len: + The length of masks. + Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. + + >>> lengths = torch.tensor([1, 3, 2, 5]) + >>> make_pad_mask(lengths) + tensor([[False, True, True, True, True], + [False, False, False, True, True], + [False, False, True, True, True], + [False, False, False, False, False]]) + """ + assert lengths.ndim == 1, lengths.ndim + max_len = max(max_len, lengths.max()) + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) + + return expaned_lengths >= lengths.unsqueeze(-1) diff --git a/egs/zipvoice/zipvoice/zipformer.py b/egs/zipvoice/zipvoice/zipformer.py new file mode 100644 index 000000000..190191cbb --- /dev/null +++ b/egs/zipvoice/zipvoice/zipformer.py @@ -0,0 +1,1648 @@ +#!/usr/bin/env python3 +# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao, +# Wei Kang +# Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +from typing import Optional, Tuple, Union + +import torch +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, + Balancer, + BiasNorm, + Dropout2, + FloatLike, + ScheduledFloat, + SwooshR, + Whiten, + limit_param_value, + penalize_abs_values_gt, + softmax, +) +from torch import Tensor, nn + + +def timestep_embedding(timesteps, dim, max_period=10000): + """Create sinusoidal timestep embeddings. + + :param timesteps: shape of (N) or (N, T) + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim) + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) + / half + ) + + if timesteps.dim() == 2: + timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N) + + args = timesteps[..., None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1) + return embedding + + +class TTSZipformer(nn.Module): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. + num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + use_time_embed: (bool): if True, do not take time embedding as additional input. + time_embed_dim: (int): the dimension of the time embedding. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + downsampling_factor: Tuple[int] = (2, 4), + num_encoder_layers: Union[int, Tuple[int]] = 4, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + encoder_dim: int = 384, + query_head_dim: int = 24, + pos_head_dim: int = 4, + value_head_dim: int = 12, + num_heads: int = 8, + feedforward_dim: int = 1536, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + use_time_embed: bool = True, + time_embed_dim: int = 192, + use_guidance_scale_embed: bool = False, + guidance_scale_embed_dim: int = 192, + use_conv: bool = True, + ) -> None: + super(TTSZipformer, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + def _assert_downsampling_factor(factors): + """assert downsampling_factor follows u-net style""" + assert factors[0] == 1 and factors[-1] == 1 + + for i in range(1, len(factors) // 2 + 1): + assert factors[i] == factors[i - 1] * 2 + + for i in range(len(factors) // 2 + 1, len(factors)): + assert factors[i] * 2 == factors[i - 1] + + _assert_downsampling_factor(downsampling_factor) + self.downsampling_factor = downsampling_factor # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + self.encoder_dim = encoder_dim + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim + self.value_head_dim = value_head_dim + self.num_heads = num_heads + + self.use_time_embed = use_time_embed + self.use_guidance_scale_embed = use_guidance_scale_embed + + self.time_embed_dim = time_embed_dim + if self.use_time_embed: + assert time_embed_dim != -1 + else: + time_embed_dim = -1 + self.guidance_scale_embed_dim = guidance_scale_embed_dim + + self.in_proj = nn.Linear(in_dim, encoder_dim) + self.out_proj = nn.Linear(encoder_dim, out_dim) + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + value_head_dim=value_head_dim, + feedforward_dim=feedforward_dim, + use_conv=use_conv, + cnn_module_kernel=cnn_module_kernel[i], + dropout=dropout, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + embed_dim=encoder_dim, + time_embed_dim=time_embed_dim, + pos_dim=pos_dim, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim, + downsample=downsampling_factor[i], + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + if self.use_time_embed: + self.time_embed = nn.Sequential( + nn.Linear(time_embed_dim, time_embed_dim * 2), + SwooshR(), + nn.Linear(time_embed_dim * 2, time_embed_dim), + ) + else: + self.time_embed = None + + if self.use_guidance_scale_embed: + self.guidance_scale_embed = ScaledLinear( + guidance_scale_embed_dim, time_embed_dim, bias=False, initial_scale=0.1 + ) + else: + self.guidance_scale_embed = None + + def forward( + self, + x: Tensor, + t: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, + guidance_scale: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + t: + A t tensor of shape (batch_size,) or (batch_size, seq_len) + padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return the output embeddings. its shape is (batch_size, output_seq_len, encoder_dim) + """ + x = x.permute(1, 0, 2) + x = self.in_proj(x) + + if t is not None: + assert t.dim() == 1 or t.dim() == 2, t.shape + time_emb = timestep_embedding(t, self.time_embed_dim) + if guidance_scale is not None: + assert ( + guidance_scale.dim() == 1 or guidance_scale.dim() == 2 + ), guidance_scale.shape + guidance_scale_emb = self.guidance_scale_embed( + timestep_embedding(guidance_scale, self.guidance_scale_embed_dim) + ) + time_emb = time_emb + guidance_scale_emb + time_emb = self.time_embed(time_emb) + else: + time_emb = None + + attn_mask = None + + for i, module in enumerate(self.encoders): + x = module( + x, + time_emb=time_emb, + src_key_padding_mask=padding_mask, + attn_mask=attn_mask, + ) + x = self.out_proj(x) + x = x.permute(1, 0, 2) + return x + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (required). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module (default=31). + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + use_conv: bool = True, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. + self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.use_conv = use_conv + + if self.use_conv: + self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + time_emb: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim) + or (seq_len, batch_size, embedding_dim) . + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + if time_emb is not None: + + src = src + time_emb + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if self.use_conv: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + + if time_emb is not None: + src = src + time_emb + + src = src + self.sequence_dropout( + self.conv_module1( + src, + src_key_padding_mask=src_key_padding_mask, + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if self.use_conv: + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + + if time_emb is not None: + src = src + time_emb + + src = src + self.sequence_dropout( + self.conv_module2( + src, + src_key_padding_mask=src_key_padding_mask, + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + embed_dim: int, + time_embed_dim: int, + pos_dim: int, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + if time_embed_dim != -1: + self.time_emb = nn.Sequential( + SwooshR(), + nn.Linear(time_embed_dim, embed_dim), + ) + else: + self.time_emb = None + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + time_emb: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + the embedding representing the current timestep: shape (batch_size, embedding_dim) + or (seq_len, batch_size, embedding_dim) . + time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim) + or (seq_len, batch_size, embedding_dim) . + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + if self.time_emb is not None: + assert time_emb is not None + time_emb = self.time_emb(time_emb) + else: + assert time_emb is None + + output = src + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + time_emb=time_emb, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + return output + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. + """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 corresponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__(self, encoder: nn.Module, dim: int, downsample: int): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(downsample) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + time_emb: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim) + or (seq_len, batch_size, embedding_dim) . + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if time_emb is not None and time_emb.dim() == 3: + time_emb = time_emb[::ds] + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask[..., ::ds] + + src = self.encoder( + src, + time_emb=time_emb, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum. + """ + + def __init__(self, downsample: int): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that just repeats the input. + """ + + def __init__(self, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the Fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0, embed_dim + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0, length_factor + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 + - x_size_left + + 1 : self.pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be sufficient to fix the problem. + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnostics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim, ( + p.shape[-1], + num_heads, + pos_head_dim, + ) + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.amp.autocast("cuda", enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.debug( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + +class FeedforwardModule(nn.Module): + """Feedforward module in TTSZipformer model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x diff --git a/egs/zipvoice/zipvoice/zipvoice_infer.py b/egs/zipvoice/zipvoice/zipvoice_infer.py new file mode 100644 index 000000000..472ad700d --- /dev/null +++ b/egs/zipvoice/zipvoice/zipvoice_infer.py @@ -0,0 +1,642 @@ +#!/usr/bin/env python3 +# Copyright 2025 Xiaomi Corp. (authors: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script generates speech with our pre-trained ZipVoice or + ZipVoice-Distill models. Required models will be automatically + downloaded from HuggingFace. + +Usage: + +Note: If you having trouble connecting to HuggingFace, + you try switch endpoint to mirror site: + +export HF_ENDPOINT=https://hf-mirror.com + +(1) Inference of a single sentence: + +python3 zipvoice/zipvoice_infer.py \ + --model-name "zipvoice_distill" \ + --prompt-wav prompt.wav \ + --prompt-text "I am a prompt." \ + --text "I am a sentence." \ + --res-wav-path result.wav + +(2) Inference of a list of sentences: +python3 zipvoice/zipvoice_infer.py \ + --model-name "zipvoice-distill" \ + --test-list test.tsv \ + --res-dir results + +`--model-name` can be `zipvoice` or `zipvoice_distill`, + which are the models before and after distillation, respectively. + +Each line of `test.tsv` is in the format of + `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`. +""" + +import argparse +import datetime as dt +import os + +import numpy as np +import safetensors.torch +import soundfile as sf +import torch +import torch.nn as nn +import torchaudio +from feature import TorchAudioFbank, TorchAudioFbankConfig +from huggingface_hub import hf_hub_download +from lhotse.utils import fix_random_seed +from model import get_distill_model, get_model +from tokenizer import TokenizerEmilia +from utils import AttributeDict +from vocos import Vocos + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-name", + type=str, + default="zipvoice_distill", + choices=["zipvoice", "zipvoice_distill"], + help="The model used for inference", + ) + + parser.add_argument( + "--test-list", + type=str, + default=None, + help="The list of prompt speech, prompt_transcription, " + "and text to synthesizein the format of " + "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.", + ) + + parser.add_argument( + "--prompt-wav", + type=str, + default=None, + help="The prompt wav to mimic", + ) + + parser.add_argument( + "--prompt-text", + type=str, + default=None, + help="The transcription of the prompt wav", + ) + + parser.add_argument( + "--text", + type=str, + default=None, + help="The text to synthesize", + ) + + parser.add_argument( + "--res-dir", + type=str, + default="results", + help="Path name of the generated wavs dir, " + "used when decdode-list is not None", + ) + + parser.add_argument( + "--res-wav-path", + type=str, + default="result.wav", + help="Path name of the generated wav path, " "used when decdode-list is None", + ) + + parser.add_argument( + "--guidance-scale", + type=float, + default=None, + help="The scale of classifier-free guidance during inference.", + ) + + parser.add_argument( + "--num-step", + type=int, + default=None, + help="The number of sampling steps.", + ) + + parser.add_argument( + "--feat-scale", + type=float, + default=0.1, + help="The scale factor of fbank feature", + ) + + parser.add_argument( + "--speed", + type=float, + default=1.0, + help="Control speech speed, 1.0 means normal, >1.0 means speed up", + ) + + parser.add_argument( + "--t-shift", + type=float, + default=0.5, + help="Shift t to smaller ones if t_shift < 1.0", + ) + + parser.add_argument( + "--target-rms", + type=float, + default=0.1, + help="Target speech normalization rms value", + ) + + parser.add_argument( + "--seed", + type=int, + default=666, + help="Random seed", + ) + + add_model_arguments(parser) + + return parser + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--fm-decoder-downsampling-factor", + type=str, + default="1,2,4,2,1", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--fm-decoder-num-layers", + type=str, + default="2,2,4,4,4", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--fm-decoder-cnn-module-kernel", + type=str, + default="31,15,7,15,31", + help="Sizes of convolutional kernels in convolution modules " + "in each encoder stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--fm-decoder-feedforward-dim", + type=int, + default=1536, + help="Feedforward dimension of the zipformer encoder layers, " + "per stack, comma separated.", + ) + + parser.add_argument( + "--fm-decoder-num-heads", + type=int, + default=4, + help="Number of attention heads in the zipformer encoder layers: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--fm-decoder-dim", + type=int, + default=512, + help="Embedding dimension in encoder stacks: a single int " + "or comma-separated list.", + ) + + parser.add_argument( + "--text-encoder-downsampling-factor", + type=str, + default="1", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--text-encoder-num-layers", + type=str, + default="4", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--text-encoder-feedforward-dim", + type=int, + default=512, + help="Feedforward dimension of the zipformer encoder layers, " + "per stack, comma separated.", + ) + + parser.add_argument( + "--text-encoder-cnn-module-kernel", + type=str, + default="9", + help="Sizes of convolutional kernels in convolution modules in " + "each encoder stack: a single int or comma-separated list.", + ) + + parser.add_argument( + "--text-encoder-num-heads", + type=int, + default=4, + help="Number of attention heads in the zipformer encoder layers: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--text-encoder-dim", + type=int, + default=192, + help="Embedding dimension in encoder stacks: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=int, + default=32, + help="Query/key dimension per head in encoder stacks: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=int, + default=12, + help="Value dimension per head in encoder stacks: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=int, + default=4, + help="Positional-encoding dimension per head in encoder stacks: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default=48, + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--time-embed-dim", + type=int, + default=192, + help="Embedding dimension of timestamps embedding.", + ) + + parser.add_argument( + "--text-embed-dim", + type=int, + default=192, + help="Embedding dimension of text embedding.", + ) + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "sampling_rate": 24000, + "frame_shift_ms": 256 / 24000 * 1000, + "feat_dim": 100, + } + ) + + return params + + +def get_vocoder(): + vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz") + return vocoder + + +def generate_sentence( + save_path: str, + prompt_text: str, + prompt_wav: str, + text: str, + model: nn.Module, + vocoder: nn.Module, + tokenizer: TokenizerEmilia, + feature_extractor: TorchAudioFbank, + device: torch.device, + num_step: int = 16, + guidance_scale: float = 1.0, + speed: float = 1.0, + t_shift: float = 0.5, + target_rms: float = 0.1, + feat_scale: float = 0.1, + sampling_rate: int = 24000, +): + """ + Generate waveform of a text based on a given prompt + waveform and its transcription. + + Args: + save_path (str): Path to save the generated wav. + prompt_text (str): Transcription of the prompt wav. + prompt_wav (str): Path to the prompt wav file. + text (str): Text to be synthesized into a waveform. + model (nn.Module): The model used for generation. + vocoder (nn.Module): The vocoder used to convert features to waveforms. + tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens. + feature_extractor (TorchAudioFbank): The feature extractor used to + extract acoustic features. + device (torch.device): The device on which computations are performed. + num_step (int, optional): Number of steps for decoding. Defaults to 16. + guidance_scale (float, optional): Scale for classifier-free guidance. + Defaults to 1.0. + speed (float, optional): Speed control. Defaults to 1.0. + t_shift (float, optional): Time shift. Defaults to 0.5. + target_rms (float, optional): Target RMS for waveform normalization. + Defaults to 0.1. + feat_scale (float, optional): Scale for features. + Defaults to 0.1. + sampling_rate (int, optional): Sampling rate for the waveform. + Defaults to 24000. + Returns: + metrics (dict): Dictionary containing time and real-time + factor metrics for processing. + """ + # Convert text to tokens + tokens = tokenizer.texts_to_token_ids([text]) + prompt_tokens = tokenizer.texts_to_token_ids([prompt_text]) + + # Load and preprocess prompt wav + prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav) + prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav))) + if prompt_rms < target_rms: + prompt_wav = prompt_wav * target_rms / prompt_rms + + if prompt_sampling_rate != sampling_rate: + resampler = torchaudio.transforms.Resample( + orig_freq=prompt_sampling_rate, new_freq=sampling_rate + ) + prompt_wav = resampler(prompt_wav) + + # Extract features from prompt wav + prompt_features = feature_extractor.extract( + prompt_wav, sampling_rate=sampling_rate + ).to(device) + prompt_features = prompt_features.unsqueeze(0) * feat_scale + prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device) + + # Start timing + start_t = dt.datetime.now() + + # Generate features + ( + pred_features, + pred_features_lens, + pred_prompt_features, + pred_prompt_features_lens, + ) = model.sample( + tokens=tokens, + prompt_tokens=prompt_tokens, + prompt_features=prompt_features, + prompt_features_lens=prompt_features_lens, + speed=speed, + t_shift=t_shift, + duration="predict", + num_step=num_step, + guidance_scale=guidance_scale, + ) + + # Postprocess predicted features + pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T) + + # Start vocoder processing + start_vocoder_t = dt.datetime.now() + wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1) + + # Calculate processing times and real-time factors + t = (dt.datetime.now() - start_t).total_seconds() + t_no_vocoder = (start_vocoder_t - start_t).total_seconds() + t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds() + wav_seconds = wav.shape[-1] / sampling_rate + rtf = t / wav_seconds + rtf_no_vocoder = t_no_vocoder / wav_seconds + rtf_vocoder = t_vocoder / wav_seconds + metrics = { + "t": t, + "t_no_vocoder": t_no_vocoder, + "t_vocoder": t_vocoder, + "wav_seconds": wav_seconds, + "rtf": rtf, + "rtf_no_vocoder": rtf_no_vocoder, + "rtf_vocoder": rtf_vocoder, + } + + # Adjust wav volume if necessary + if prompt_rms < target_rms: + wav = wav * prompt_rms / target_rms + wav = wav[0].cpu().numpy() + sf.write(save_path, wav, sampling_rate) + + return metrics + + +def generate( + res_dir: str, + test_list: str, + model: nn.Module, + vocoder: nn.Module, + tokenizer: TokenizerEmilia, + feature_extractor: TorchAudioFbank, + device: torch.device, + num_step: int = 16, + guidance_scale: float = 1.0, + speed: float = 1.0, + t_shift: float = 0.5, + target_rms: float = 0.1, + feat_scale: float = 0.1, + sampling_rate: int = 24000, +): + total_t = [] + total_t_no_vocoder = [] + total_t_vocoder = [] + total_wav_seconds = [] + + with open(test_list, "r") as fr: + lines = fr.readlines() + + for i, line in enumerate(lines): + wav_name, prompt_text, prompt_wav, text = line.strip().split("\t") + save_path = f"{res_dir}/{wav_name}.wav" + metrics = generate_sentence( + save_path=save_path, + prompt_text=prompt_text, + prompt_wav=prompt_wav, + text=text, + model=model, + vocoder=vocoder, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + device=device, + num_step=num_step, + guidance_scale=guidance_scale, + speed=speed, + t_shift=t_shift, + target_rms=target_rms, + feat_scale=feat_scale, + sampling_rate=sampling_rate, + ) + print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}") + total_t.append(metrics["t"]) + total_t_no_vocoder.append(metrics["t_no_vocoder"]) + total_t_vocoder.append(metrics["t_vocoder"]) + total_wav_seconds.append(metrics["wav_seconds"]) + + print(f"Average RTF: {np.sum(total_t)/np.sum(total_wav_seconds):.4f}") + print( + f"Average RTF w/o vocoder: " + f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}" + ) + print( + f"Average RTF vocoder: " + f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}" + ) + + +@torch.inference_mode() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + model_defaults = { + "zipvoice": { + "num_step": 16, + "guidance_scale": 1.0, + }, + "zipvoice_distill": { + "num_step": 8, + "guidance_scale": 3.0, + }, + } + + model_specific_defaults = model_defaults.get(params.model_name, {}) + + for param, value in model_specific_defaults.items(): + if getattr(params, param) == parser.get_default(param): + setattr(params, param, value) + print(f"Setting {param} to default value: {value}") + + assert (params.test_list is not None) ^ ( + (params.prompt_wav and params.prompt_text and params.text) is not None + ), ( + "For inference, please provide prompts and text with either '--test-list'" + " or '--prompt-wav, --prompt-text and --text'." + ) + + if torch.cuda.is_available(): + params.device = torch.device("cuda", 0) + else: + params.device = torch.device("cpu") + + token_file = hf_hub_download("zhu-han/ZipVoice", filename="tokens_emilia.txt") + + tokenizer = TokenizerEmilia(token_file) + + params.vocab_size = tokenizer.vocab_size + params.pad_id = tokenizer.pad_id + fix_random_seed(params.seed) + + if params.model_name == "zipvoice_distill": + model = get_distill_model(params) + model_ckpt = hf_hub_download( + "zhu-han/ZipVoice", filename="exp_zipvoice_distill/model.safetensors" + ) + else: + model = get_model(params) + model_ckpt = hf_hub_download( + "zhu-han/ZipVoice", filename="exp_zipvoice/model.safetensors" + ) + + safetensors.torch.load_model(model, model_ckpt) + + model = model.to(params.device) + model.eval() + + vocoder = get_vocoder() + vocoder = vocoder.to(params.device) + vocoder.eval() + + config = TorchAudioFbankConfig( + sampling_rate=params.sampling_rate, + n_mels=100, + n_fft=1024, + hop_length=256, + ) + feature_extractor = TorchAudioFbank(config) + + if params.test_list: + os.makedirs(params.res_dir, exist_ok=True) + generate( + res_dir=params.res_dir, + test_list=params.test_list, + model=model, + vocoder=vocoder, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + device=params.device, + num_step=params.num_step, + guidance_scale=params.guidance_scale, + speed=params.speed, + t_shift=params.t_shift, + target_rms=params.target_rms, + feat_scale=params.feat_scale, + sampling_rate=params.sampling_rate, + ) + else: + generate_sentence( + save_path=params.res_wav_path, + prompt_text=params.prompt_text, + prompt_wav=params.prompt_wav, + text=params.text, + model=model, + vocoder=vocoder, + tokenizer=tokenizer, + feature_extractor=feature_extractor, + device=params.device, + num_step=params.num_step, + guidance_scale=params.guidance_scale, + speed=params.speed, + t_shift=params.t_shift, + target_rms=params.target_rms, + feat_scale=params.feat_scale, + sampling_rate=params.sampling_rate, + ) + print("Done") + + +if __name__ == "__main__": + main()