diff --git a/egs/zipvoice/README.md b/egs/zipvoice/README.md deleted file mode 100644 index 0c97d7ed8..000000000 --- a/egs/zipvoice/README.md +++ /dev/null @@ -1,412 +0,0 @@ -## ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching - - -[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](http://arxiv.org/abs/2506.13053) -[![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://zipvoice.github.io/) - - -## Overview -ZipVoice is a high-quality zero-shot TTS model with a small model size and fast inference speed. -#### Key features: - -- Small and fast: only 123M parameters. - -- High-quality: state-of-the-art voice cloning performance in speaker similarity, intelligibility, and naturalness. - -- Multi-lingual: support Chinese and English. - - -## News -**2025/06/16**: 🔥 ZipVoice is released. - - -## Installation - -* Clone icefall repository and change to zipvoice directory: - -```bash -git clone https://github.com/k2-fsa/icefall.git -cd icefall/egs/zipvoice -``` - -* Create a Python virtual environment (optional but recommended): - -```bash -python3 -m venv venv -source venv/bin/activate -``` - -* Install the required packages: - -```bash -pip install -r requirements.txt -``` - -## Usage - -To generate speech with our pre-trained ZipVoice or ZipVoice-Distill models, use the following commands (Required models will be downloaded from HuggingFace): - -### 1. Inference of a single sentence: -```bash -python3 zipvoice/zipvoice_infer.py \ - --model-name "zipvoice_distill" \ - --prompt-wav prompt.wav \ - --prompt-text "I am the transcription of the prompt wav." \ - --text "I am the text to be synthesized." \ - --res-wav-path result.wav - -# Example with a pre-defined prompt wav and text -python3 zipvoice/zipvoice_infer.py \ - --model-name "zipvoice_distill" \ - --prompt-wav assets/prompt-en.wav \ - --prompt-text "Some call me nature, others call me mother nature. I've been here for over four point five billion years, twenty two thousand five hundred times longer than you." \ - --text "Welcome to use our tts model, have fun!" \ - --res-wav-path result.wav -``` - -### 2. Inference of a list of sentences: -```bash -python3 zipvoice/zipvoice_infer.py \ - --model-name "zipvoice_distill" \ - --test-list test.tsv \ - --res-dir results/test -``` - -- `--model-name` can be `zipvoice` or `zipvoice_distill`, which are models before and after distillation, respectively. -- Each line of `test.tsv` is in the format of `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`. - - -> **Note:** If you having trouble connecting to HuggingFace, try: -```bash -export HF_ENDPOINT=https://hf-mirror.com -``` - -## Training Your Own Model - -The following steps show how to train a model from scratch on Emilia and LibriTTS datasets, respectively. - -### 0. Install dependencies for training - -```bash -# Install pytorch and k2. -# If you want to use different versions, please refer to https://k2-fsa.org/get-started/k2/ for details. -# For users in China mainland, please refer to https://k2-fsa.org/zh-CN/get-started/k2/ - -# Note: Make sure you have installed the correct version of PyTorch and k2 that matches your CUDA version. 
-# For example, if you want to use PyTorch 2.5.1 with CUDA 12.1, you can install PyTorch and k2 as follows: - -pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121 -pip install k2==1.24.4.dev20250208+cuda12.1.torch2.5.1 -f https://k2-fsa.github.io/k2/cuda.html - -pip install -r ../../requirements.txt ``` - -### 1. Data Preparation - -#### 1.1 Prepare the Emilia dataset - -```bash -bash scripts/prepare_emilia.sh ``` - -See [scripts/prepare_emilia.sh](scripts/prepare_emilia.sh) for step-by-step instructions. - -#### 1.2 Prepare the LibriTTS dataset - -```bash -bash scripts/prepare_libritts.sh ``` - -See [scripts/prepare_libritts.sh](scripts/prepare_libritts.sh) for step-by-step instructions. - -### 2. Training - -#### 2.1 Training on Emilia - -
-Expand to view training steps - -##### 2.1.1 Train the ZipVoice model - -- Training: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/train_flow.py \ - --world-size 8 \ - --use-fp16 1 \ - --dataset emilia \ - --max-duration 500 \ - --lr-hours 30000 \ - --lr-batches 7500 \ - --token-file "data/tokens_emilia.txt" \ - --manifest-dir "data/fbank" \ - --num-epochs 11 \ - --exp-dir zipvoice/exp_zipvoice -``` - -- Average the checkpoints to produce the final model: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/generate_averaged_model.py \ - --epoch 11 \ - --avg 4 \ - --distill 0 \ - --token-file data/tokens_emilia.txt \ - --dataset "emilia" \ - --exp-dir ./zipvoice/exp_zipvoice -# The generated model is zipvoice/exp_zipvoice/epoch-11-avg-4.pt -``` - -##### 2.1.2. Train the ZipVoice-Distill model (Optional) - -- The first-stage distillation: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/train_distill.py \ - --world-size 8 \ - --use-fp16 1 \ - --tensorboard 1 \ - --dataset "emilia" \ - --base-lr 0.0005 \ - --max-duration 500 \ - --token-file "data/tokens_emilia.txt" \ - --manifest-dir "data/fbank" \ - --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \ - --num-updates 60000 \ - --distill-stage "first" \ - --exp-dir zipvoice/exp_zipvoice_distill_1stage -``` - -- Average checkpoints for the second-stage initialization: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/generate_averaged_model.py \ - --iter 60000 \ - --avg 7 \ - --distill 1 \ - --token-file data/tokens_emilia.txt \ - --dataset "emilia" \ - --exp-dir ./zipvoice/exp_zipvoice_distill_1stage -# The generated model is zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt -``` - -- The second-stage distillation: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/train_distill.py \ - --world-size 8 \ - --use-fp16 1 \ - --tensorboard 1 \ - --dataset "emilia" \ - --base-lr 0.0001 \ - --max-duration 200 \ - --token-file "data/tokens_emilia.txt" \ - --manifest-dir "data/fbank" \ - --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \ - --num-updates 2000 \ - --distill-stage "second" \ - --exp-dir zipvoice/exp_zipvoice_distill_new -``` -
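The distillation commands above enable TensorBoard logging (`--tensorboard 1`). A minimal monitoring sketch, assuming the event files land under `<exp-dir>/tensorboard` as in other icefall recipes; adjust `--logdir` if your run writes elsewhere:

```bash
# Assumed log location; point --logdir at wherever your run writes its event files.
tensorboard --logdir zipvoice/exp_zipvoice_distill_1stage/tensorboard --port 6006
```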
- - -#### 2.2 Training on LibriTTS - -
-Expand to view training steps - -##### 2.2.1 Train the ZipVoice model - -- Training: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/train_flow.py \ - --world-size 8 \ - --use-fp16 1 \ - --dataset libritts \ - --max-duration 250 \ - --lr-epochs 10 \ - --lr-batches 7500 \ - --token-file "data/tokens_libritts.txt" \ - --manifest-dir "data/fbank" \ - --num-epochs 60 \ - --exp-dir zipvoice/exp_zipvoice_libritts -``` - -- Average the checkpoints to produce the final model: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/generate_averaged_model.py \ - --epoch 60 \ - --avg 10 \ - --distill 0 \ - --token-file data/tokens_libritts.txt \ - --dataset "libritts" \ - --exp-dir ./zipvoice/exp_zipvoice_libritts -# The generated model is zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt -``` - -##### 2.1.2 Train the ZipVoice-Distill model (Optional) - -- The first-stage distillation: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/train_distill.py \ - --world-size 8 \ - --use-fp16 1 \ - --tensorboard 1 \ - --dataset "libritts" \ - --base-lr 0.001 \ - --max-duration 250 \ - --token-file "data/tokens_libritts.txt" \ - --manifest-dir "data/fbank" \ - --teacher-model zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \ - --num-epochs 6 \ - --distill-stage "first" \ - --exp-dir zipvoice/exp_zipvoice_distill_1stage_libritts -``` - -- Average checkpoints for the second-stage initialization: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 ./zipvoice/generate_averaged_model.py \ - --epoch 6 \ - --avg 3 \ - --distill 1 \ - --token-file data/tokens_libritts.txt \ - --dataset "libritts" \ - --exp-dir ./zipvoice/exp_zipvoice_distill_1stage_libritts -# The generated model is zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt -``` - -- The second-stage distillation: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/train_distill.py \ - --world-size 8 \ - --use-fp16 1 \ - --tensorboard 1 \ - --dataset "libritts" \ - --base-lr 0.001 \ - --max-duration 250 \ - --token-file "data/tokens_libritts.txt" \ - --manifest-dir "data/fbank" \ - --teacher-model zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \ - --num-epochs 6 \ - --distill-stage "second" \ - --exp-dir zipvoice/exp_zipvoice_distill_libritts -``` - -- Average checkpoints to produce the final model: - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 ./zipvoice/generate_averaged_model.py \ - --epoch 6 \ - --avg 3 \ - --distill 1 \ - --token-file data/tokens_libritts.txt \ - --dataset "libritts" \ - --exp-dir ./zipvoice/exp_zipvoice_distill_libritts -# The generated model is ./zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt -``` -
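Before moving on to inference, it can help to confirm that each stage above actually produced its averaged checkpoint. A small sanity-check sketch; the paths are taken from the "generated model" comments in the commands above:

```bash
# Verify that the averaged checkpoints from the three LibriTTS stages exist.
for ckpt in \
  zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
  zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \
  zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt; do
  [ -f "$ckpt" ] && echo "found:   $ckpt" || echo "missing: $ckpt"
done
```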
- - -### 3. Inference with the trained model - -#### 3.1 Inference with the model trained on Emilia -
-Expand to view inference commands. - -##### 3.1.1 ZipVoice model (before distillation): -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/infer.py \ - --checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \ - --distill 0 \ - --token-file "data/tokens_emilia.txt" \ - --test-list test.tsv \ - --res-dir results/test \ - --num-step 16 \ - --guidance-scale 1 -``` - -##### 3.1.2 ZipVoice-Distill model (after distillation): -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/infer.py \ - --checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \ - --distill 1 \ - --token-file "data/tokens_emilia.txt" \ - --test-list test.tsv \ - --res-dir results/test_distill \ - --num-step 8 \ - --guidance-scale 3 -``` -
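Both commands read the `test.tsv` format described in the Usage section, i.e. tab-separated lines of `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`. A minimal illustrative file using the bundled English prompt; the utterance name and target text are placeholders:

```bash
# Illustrative only: write a one-line test.tsv in the expected tab-separated format.
printf '%s\t%s\t%s\t%s\n' \
  "utt_0001" \
  "Some call me nature, others call me mother nature. I've been here for over four point five billion years, twenty two thousand five hundred times longer than you." \
  "assets/prompt-en.wav" \
  "Welcome to use our tts model, have fun!" \
  > test.tsv
```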
- - -#### 3.2 Inference with the model trained on LibriTTS - -
-Expand to view inference commands. - -##### 3.2.1 ZipVoice model (before distillation): -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/infer.py \ - --checkpoint zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \ - --distill 0 \ - --token-file "data/tokens_libritts.txt" \ - --test-list test.tsv \ - --res-dir results/test_libritts \ - --num-step 8 \ - --guidance-scale 1 \ - --target-rms 1.0 \ - --t-shift 0.7 -``` - -##### 3.2.2 ZipVoice-Distill model (after distillation): - -```bash -export PYTHONPATH=../../:$PYTHONPATH -python3 zipvoice/infer.py \ - --checkpoint zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt \ - --distill 1 \ - --token-file "data/tokens_libritts.txt" \ - --test-list test.tsv \ - --res-dir results/test_distill_libritts \ - --num-step 4 \ - --guidance-scale 3 \ - --target-rms 1.0 \ - --t-shift 0.7 -``` -
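Section 4 below points at `local/evaluate.sh` for the full benchmark pipeline; as a quicker spot check, the Hubert-based WER script under `local/` can be run directly on a synthesized directory. The paths here are assumptions based on the `--res-dir` and `--test-list` values used above:

```bash
# Spot-check intelligibility of the LibriTTS outputs with the Hubert ASR model
# (facebook/hubert-large-ls960-ft is downloaded if --model-path is not given).
python3 local/evaluate_wer_hubert.py \
  --wav-path results/test_libritts \
  --test-list test.tsv \
  --decode-path results/test_libritts/wer_details.txt
```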
- -### 4. Evaluation on benchmarks - -See [local/evaluate.sh](local/evaluate.sh) for details of objective metrics evaluation -on three test sets, i.e., LibriSpeech-PC test-clean, Seed-TTS test-en and Seed-TTS test-zh. - - -## Citation - -```bibtex -@article{zhu-2025-zipvoice, - title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching}, - author={Han Zhu and Wei Kang and Zengwei Yao and Liyong Guo and Fangjun Kuang and Zhaoqing Li and Weiji Zhuang and Long Lin and Daniel Povey} - journal={arXiv preprint arXiv:2506.13053}, - year={2025}, -} -``` diff --git a/egs/zipvoice/assets/prompt-en.wav b/egs/zipvoice/assets/prompt-en.wav deleted file mode 100644 index b7047ce9b..000000000 Binary files a/egs/zipvoice/assets/prompt-en.wav and /dev/null differ diff --git a/egs/zipvoice/local/compute_fbank.py b/egs/zipvoice/local/compute_fbank.py deleted file mode 100644 index 0c440b995..000000000 --- a/egs/zipvoice/local/compute_fbank.py +++ /dev/null @@ -1,287 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import os -from concurrent.futures import ProcessPoolExecutor as Pool -from pathlib import Path -from typing import Optional - -import lhotse -import torch -from feature import TorchAudioFbank, TorchAudioFbankConfig -from lhotse import ( - CutSet, - LilcomChunkyWriter, - load_manifest_lazy, - set_audio_duration_mismatch_tolerance, -) - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). 
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def str2bool(v): - """Used in argparse.ArgumentParser.add_argument to indicate - that a type is a bool type and user can enter - - - yes, true, t, y, 1, to represent True - - no, false, f, n, 0, to represent False - - See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa - """ - if isinstance(v, bool): - return v - if v.lower() in ("yes", "true", "t", "y", "1"): - return True - elif v.lower() in ("no", "false", "f", "n", "0"): - return False - else: - raise argparse.ArgumentTypeError("Boolean value expected.") - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--sampling-rate", - type=int, - default=24000, - help="The target sampling rate, the audio will be resampled to this sampling_rate.", - ) - - parser.add_argument( - "--frame-shift", - type=int, - default=256, - help="Frame shift in samples", - ) - - parser.add_argument( - "--frame-length", - type=int, - default=1024, - help="Frame length in samples", - ) - - parser.add_argument( - "--num-mel-bins", - type=int, - default=100, - help="The num of mel filters.", - ) - - parser.add_argument( - "--dataset", - type=str, - help="Dataset name.", - ) - - parser.add_argument( - "--subset", - type=str, - help="The subset of the dataset.", - ) - - parser.add_argument( - "--source-dir", - type=str, - default="data/manifests", - help="The source directory of manifest files.", - ) - - parser.add_argument( - "--dest-dir", - type=str, - default="data/fbank", - help="The destination directory of manifest files.", - ) - - parser.add_argument( - "--split-cuts", - type=str2bool, - default=False, - help="Whether to use splited cuts.", - ) - - parser.add_argument( - "--split-begin", - type=int, - help="Start idx of splited cuts.", - ) - - parser.add_argument( - "--split-end", - type=int, - help="End idx of splited cuts.", - ) - - parser.add_argument( - "--batch-duration", - type=int, - default=1000, - help="The batch duration when computing the features.", - ) - - parser.add_argument( - "--num-jobs", type=int, default=20, help="The number of extractor workers." 
- ) - - return parser.parse_args() - - -def compute_fbank_split_single(params, idx): - lhotse.set_audio_duration_mismatch_tolerance(0.1) # for emilia - src_dir = Path(params.source_dir) - output_dir = Path(params.dest_dir) - num_mel_bins = params.num_mel_bins - - if not src_dir.exists(): - logging.error(f"{src_dir} not exists") - return - - if not output_dir.exists(): - output_dir.mkdir(parents=True, exist_ok=True) - - num_digits = 8 - - config = TorchAudioFbankConfig( - sampling_rate=params.sampling_rate, - n_mels=params.num_mel_bins, - n_fft=params.frame_length, - hop_length=params.frame_shift, - ) - extractor = TorchAudioFbank(config) - - prefix = params.dataset - subset = params.subset - suffix = "jsonl.gz" - - idx = f"{idx}".zfill(num_digits) - cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}" - - if (src_dir / cuts_filename).is_file(): - logging.info(f"Loading manifests {src_dir / cuts_filename}") - cut_set = load_manifest_lazy(src_dir / cuts_filename) - else: - logging.warning(f"Raw {cuts_filename} not exists, skipping") - return - - cut_set = cut_set.resample(params.sampling_rate) - - if (output_dir / cuts_filename).is_file(): - logging.info(f"{cuts_filename} already exists - skipping.") - return - - logging.info(f"Processing {subset}.{idx} of {prefix}") - - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}", - num_workers=4, - batch_duration=params.batch_duration, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - cut_set.to_file(output_dir / cuts_filename) - - -def compute_fbank_split(params): - if params.split_end < params.split_begin: - logging.warning( - f"Split begin should be smaller than split end, given " - f"{params.split_begin} -> {params.split_end}." 
- ) - - with Pool(max_workers=params.num_jobs) as pool: - futures = [ - pool.submit(compute_fbank_split_single, params, i) - for i in range(params.split_begin, params.split_end) - ] - for f in futures: - f.result() - f.done() - - -def compute_fbank(params): - src_dir = Path(params.source_dir) - output_dir = Path(params.dest_dir) - num_jobs = params.num_jobs - num_mel_bins = params.num_mel_bins - - prefix = params.dataset - subset = params.subset - suffix = "jsonl.gz" - - cut_set_name = f"{prefix}_cuts_{subset}.{suffix}" - - if (src_dir / cut_set_name).is_file(): - logging.info(f"Loading manifests {src_dir / cut_set_name}") - cut_set = load_manifest_lazy(src_dir / cut_set_name) - else: - recordings = load_manifest_lazy( - src_dir / f"{prefix}_recordings_{subset}.{suffix}" - ) - supervisions = load_manifest_lazy( - src_dir / f"{prefix}_supervisions_{subset}.{suffix}" - ) - cut_set = CutSet.from_manifests( - recordings=recordings, - supervisions=supervisions, - ) - - cut_set = cut_set.resample(params.sampling_rate) - - config = TorchAudioFbankConfig( - sampling_rate=params.sampling_rate, - n_mels=params.num_mel_bins, - n_fft=params.frame_length, - hop_length=params.frame_shift, - ) - extractor = TorchAudioFbank(config) - - cuts_filename = f"{prefix}_cuts_{subset}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{prefix} {subset} already exists - skipping.") - return - logging.info(f"Processing {subset} of {prefix}") - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{subset}", - num_jobs=num_jobs, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - if args.split_cuts: - compute_fbank_split(params=args) - else: - compute_fbank(params=args) diff --git a/egs/zipvoice/local/evaluate_sim.py b/egs/zipvoice/local/evaluate_sim.py deleted file mode 100644 index df439cf2c..000000000 --- a/egs/zipvoice/local/evaluate_sim.py +++ /dev/null @@ -1,508 +0,0 @@ -""" -Calculate pairwise Speaker Similarity betweeen two speech directories. -SV model wavlm_large_finetune.pth is downloaded from - https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification -SSL model wavlm_large.pt is downloaded from - https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt -""" -import argparse -import logging -import os - -import librosa -import numpy as np -import soundfile as sf -import torch -import torch.nn as nn -import torch.nn.functional as F -from tqdm import tqdm - -logging.basicConfig(level=logging.INFO) - - -def get_parser(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--eval-path", type=str, help="path of the evaluated speech directory" - ) - parser.add_argument( - "--test-list", - type=str, - help="path of the file list that contains the corresponding " - "relationship between the prompt and evaluated speech. 
" - "The first column is the wav name and the third column is the prompt speech", - ) - parser.add_argument( - "--sv-model-path", - type=str, - default="model/UniSpeech/wavlm_large_finetune.pth", - help="path of the wavlm-based ECAPA-TDNN model", - ) - parser.add_argument( - "--ssl-model-path", - type=str, - default="model/s3prl/wavlm_large.pt", - help="path of the wavlm SSL model", - ) - return parser - - -class SpeakerSimilarity: - def __init__( - self, - sv_model_path="model/UniSpeech/wavlm_large_finetune.pth", - ssl_model_path="model/s3prl/wavlm_large.pt", - ): - """ - Initialize - """ - self.sample_rate = 16000 - self.channels = 1 - self.device = ( - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - ) - logging.info("[Speaker Similarity] Using device: {}".format(self.device)) - self.model = ECAPA_TDNN_WAVLLM( - feat_dim=1024, - channels=512, - emb_dim=256, - sr=16000, - ssl_model_path=ssl_model_path, - ) - state_dict = torch.load( - sv_model_path, map_location=lambda storage, loc: storage - ) - self.model.load_state_dict(state_dict["model"], strict=False) - self.model.to(self.device) - self.model.eval() - - def get_embeddings(self, wav_list, dtype="float32"): - """ - Get embeddings - """ - - def _load_speech_task(fname, sample_rate): - - wav_data, sr = sf.read(fname, dtype=dtype) - if sr != sample_rate: - wav_data = librosa.resample( - wav_data, orig_sr=sr, target_sr=self.sample_rate - ) - wav_data = torch.from_numpy(wav_data) - - return wav_data - - embd_lst = [] - for file_path in tqdm(wav_list): - speech = _load_speech_task(file_path, self.sample_rate) - speech = speech.to(self.device) - with torch.no_grad(): - embd = self.model([speech]) - embd_lst.append(embd) - - return embd_lst - - def score( - self, - eval_path, - test_list, - dtype="float32", - ): - """ - Computes the Speaker Similarity (SIM-o) between two directories of speech files. - - Parameters: - - eval_path (str): Path to the directory containing evaluation speech files. - - test_list (str): Path to the file containing the corresponding relationship - between prompt and evaluated speech. - - dtype (str, optional): Data type for loading speech. Default is "float32". - - Returns: - - float: The Speaker Similarity (SIM-o) score between the two directories - of speech files. 
- """ - prompt_wavs = [] - eval_wavs = [] - with open(test_list, "r") as fr: - lines = fr.readlines() - for line in lines: - wav_name, prompt_text, prompt_wav, text = line.strip().split("\t") - prompt_wavs.append(prompt_wav) - eval_wavs.append(os.path.join(eval_path, wav_name + ".wav")) - embds_prompt = self.get_embeddings(prompt_wavs, dtype=dtype) - - embds_eval = self.get_embeddings(eval_wavs, dtype=dtype) - - # Check if embeddings are empty - if len(embds_prompt) == 0: - logging.info("[Speaker Similarity] real set dir is empty, exiting...") - return -1 - if len(embds_eval) == 0: - logging.info("[Speaker Similarity] eval set dir is empty, exiting...") - return -1 - - scores = [] - for real_embd, eval_embd in zip(embds_prompt, embds_eval): - scores.append( - torch.nn.functional.cosine_similarity(real_embd, eval_embd, dim=-1) - .detach() - .cpu() - .numpy() - ) - - return np.mean(scores) - - -# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN - -""" Res2Conv1d + BatchNorm1d + ReLU -""" - - -class Res2Conv1dReluBn(nn.Module): - """ - in_channels == out_channels == channels - """ - - def __init__( - self, - channels, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - bias=True, - scale=4, - ): - super().__init__() - assert channels % scale == 0, "{} % {} != 0".format(channels, scale) - self.scale = scale - self.width = channels // scale - self.nums = scale if scale == 1 else scale - 1 - - self.convs = [] - self.bns = [] - for i in range(self.nums): - self.convs.append( - nn.Conv1d( - self.width, - self.width, - kernel_size, - stride, - padding, - dilation, - bias=bias, - ) - ) - self.bns.append(nn.BatchNorm1d(self.width)) - self.convs = nn.ModuleList(self.convs) - self.bns = nn.ModuleList(self.bns) - - def forward(self, x): - out = [] - spx = torch.split(x, self.width, 1) - for i in range(self.nums): - if i == 0: - sp = spx[i] - else: - sp = sp + spx[i] - # Order: conv -> relu -> bn - sp = self.convs[i](sp) - sp = self.bns[i](F.relu(sp)) - out.append(sp) - if self.scale != 1: - out.append(spx[self.nums]) - out = torch.cat(out, dim=1) - - return out - - -""" Conv1d + BatchNorm1d + ReLU -""" - - -class Conv1dReluBn(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0, - dilation=1, - bias=True, - ): - super().__init__() - self.conv = nn.Conv1d( - in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias - ) - self.bn = nn.BatchNorm1d(out_channels) - - def forward(self, x): - return self.bn(F.relu(self.conv(x))) - - -""" The SE connection of 1D case. -""" - - -class SE_Connect(nn.Module): - def __init__(self, channels, se_bottleneck_dim=128): - super().__init__() - self.linear1 = nn.Linear(channels, se_bottleneck_dim) - self.linear2 = nn.Linear(se_bottleneck_dim, channels) - - def forward(self, x): - out = x.mean(dim=2) - out = F.relu(self.linear1(out)) - out = torch.sigmoid(self.linear2(out)) - out = x * out.unsqueeze(2) - - return out - - -""" SE-Res2Block of the ECAPA-TDNN architecture. 
-""" - - -# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale): -# return nn.Sequential( -# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0), -# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale), -# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0), -# SE_Connect(channels) -# ) - - -class SE_Res2Block(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - scale, - se_bottleneck_dim, - ): - super().__init__() - self.Conv1dReluBn1 = Conv1dReluBn( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - self.Res2Conv1dReluBn = Res2Conv1dReluBn( - out_channels, kernel_size, stride, padding, dilation, scale=scale - ) - self.Conv1dReluBn2 = Conv1dReluBn( - out_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim) - - self.shortcut = None - if in_channels != out_channels: - self.shortcut = nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - ) - - def forward(self, x): - residual = x - if self.shortcut: - residual = self.shortcut(x) - - x = self.Conv1dReluBn1(x) - x = self.Res2Conv1dReluBn(x) - x = self.Conv1dReluBn2(x) - x = self.SE_Connect(x) - - return x + residual - - -""" Attentive weighted mean and standard deviation pooling. -""" - - -class AttentiveStatsPool(nn.Module): - def __init__(self, in_dim, attention_channels=128, global_context_att=False): - super().__init__() - self.global_context_att = global_context_att - - # Use Conv1d with stride == 1 rather than Linear, - # then we don't need to transpose inputs. - if global_context_att: - self.linear1 = nn.Conv1d( - in_dim * 3, attention_channels, kernel_size=1 - ) # equals W and b in the paper - else: - self.linear1 = nn.Conv1d( - in_dim, attention_channels, kernel_size=1 - ) # equals W and b in the paper - self.linear2 = nn.Conv1d( - attention_channels, in_dim, kernel_size=1 - ) # equals V and k in the paper - - def forward(self, x): - - if self.global_context_att: - context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) - context_std = torch.sqrt( - torch.var(x, dim=-1, keepdim=True) + 1e-10 - ).expand_as(x) - x_in = torch.cat((x, context_mean, context_std), dim=1) - else: - x_in = x - - # DON'T use ReLU here! In experiments, I find ReLU hard to converge. 
- alpha = torch.tanh(self.linear1(x_in)) - # alpha = F.relu(self.linear1(x_in)) - alpha = torch.softmax(self.linear2(alpha), dim=2) - mean = torch.sum(alpha * x, dim=2) - residuals = torch.sum(alpha * (x**2), dim=2) - mean**2 - std = torch.sqrt(residuals.clamp(min=1e-9)) - return torch.cat([mean, std], dim=1) - - -class ECAPA_TDNN_WAVLLM(nn.Module): - def __init__( - self, - feat_dim=80, - channels=512, - emb_dim=192, - global_context_att=False, - sr=16000, - ssl_model_path=None, - ): - super().__init__() - self.sr = sr - - if ssl_model_path is None: - self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large") - else: - self.feature_extract = torch.hub.load( - os.path.dirname(ssl_model_path), - "wavlm_local", - source="local", - ckpt=ssl_model_path, - ) - - if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( - self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention" - ): - self.feature_extract.model.encoder.layers[ - 23 - ].self_attn.fp32_attention = False - if len(self.feature_extract.model.encoder.layers) == 24 and hasattr( - self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention" - ): - self.feature_extract.model.encoder.layers[ - 11 - ].self_attn.fp32_attention = False - - self.feat_num = self.get_feat_num() - self.feature_weight = nn.Parameter(torch.zeros(self.feat_num)) - - self.instance_norm = nn.InstanceNorm1d(feat_dim) - # self.channels = [channels] * 4 + [channels * 3] - self.channels = [channels] * 4 + [1536] - - self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2) - self.layer2 = SE_Res2Block( - self.channels[0], - self.channels[1], - kernel_size=3, - stride=1, - padding=2, - dilation=2, - scale=8, - se_bottleneck_dim=128, - ) - self.layer3 = SE_Res2Block( - self.channels[1], - self.channels[2], - kernel_size=3, - stride=1, - padding=3, - dilation=3, - scale=8, - se_bottleneck_dim=128, - ) - self.layer4 = SE_Res2Block( - self.channels[2], - self.channels[3], - kernel_size=3, - stride=1, - padding=4, - dilation=4, - scale=8, - se_bottleneck_dim=128, - ) - - # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1) - cat_channels = channels * 3 - self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1) - self.pooling = AttentiveStatsPool( - self.channels[-1], - attention_channels=128, - global_context_att=global_context_att, - ) - self.bn = nn.BatchNorm1d(self.channels[-1] * 2) - self.linear = nn.Linear(self.channels[-1] * 2, emb_dim) - - def get_feat_num(self): - self.feature_extract.eval() - wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)] - with torch.no_grad(): - features = self.feature_extract(wav) - select_feature = features["hidden_states"] - if isinstance(select_feature, (list, tuple)): - return len(select_feature) - else: - return 1 - - def get_feat(self, x): - with torch.no_grad(): - x = self.feature_extract([sample for sample in x]) - - x = x["hidden_states"] - if isinstance(x, (list, tuple)): - x = torch.stack(x, dim=0) - else: - x = x.unsqueeze(0) - norm_weights = ( - F.softmax(self.feature_weight, dim=-1) - .unsqueeze(-1) - .unsqueeze(-1) - .unsqueeze(-1) - ) - x = (norm_weights * x).sum(dim=0) - x = torch.transpose(x, 1, 2) + 1e-6 - - x = self.instance_norm(x) - return x - - def forward(self, x): - x = self.get_feat(x) - - out1 = self.layer1(x) - out2 = self.layer2(out1) - out3 = self.layer3(out2) - out4 = self.layer4(out3) - - out = torch.cat([out2, out3, out4], dim=1) - out = F.relu(self.conv(out)) - out = 
self.bn(self.pooling(out)) - out = self.linear(out) - - return out - - -if __name__ == "__main__": - parser = get_parser() - args = parser.parse_args() - SIM = SpeakerSimilarity( - sv_model_path=args.sv_model_path, ssl_model_path=args.ssl_model_path - ) - score = SIM.score(args.eval_path, args.test_list) - logging.info(f"SIM-o score: {score:.3f}") diff --git a/egs/zipvoice/local/evaluate_utmos.py b/egs/zipvoice/local/evaluate_utmos.py deleted file mode 100644 index 369e139c1..000000000 --- a/egs/zipvoice/local/evaluate_utmos.py +++ /dev/null @@ -1,294 +0,0 @@ -""" -Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system -adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo - -# Download model checkpoints -wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -P model/huggingface/utmos/utmos.pt -wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -P model/huggingface/utmos/wav2vec_small.pt -""" - -import argparse -import logging -import os - -import fairseq -import librosa -import numpy as np -import pytorch_lightning as pl -import soundfile as sf -import torch -import torch.nn as nn -from tqdm import tqdm - -logging.basicConfig(level=logging.INFO) - - -def get_parser(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--wav-path", type=str, help="path of the evaluated speech directory" - ) - parser.add_argument( - "--utmos-model-path", - type=str, - default="model/huggingface/utmos/utmos.pt", - help="path of the UTMOS model", - ) - parser.add_argument( - "--ssl-model-path", - type=str, - default="model/huggingface/utmos/wav2vec_small.pt", - help="path of the wav2vec SSL model", - ) - return parser - - -class UTMOSScore: - """Predicting score for each audio clip.""" - - def __init__(self, utmos_model_path, ssl_model_path): - self.sample_rate = 16000 - self.device = ( - torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - ) - self.model = ( - BaselineLightningModule.load_from_checkpoint( - utmos_model_path, ssl_model_path=ssl_model_path - ) - .eval() - .to(self.device) - ) - - def score(self, wavs: torch.Tensor) -> torch.Tensor: - """ - Args: - wavs: waveforms to be evaluated. When len(wavs) == 1 or 2, - the model processes the input as a single audio clip. The model - performs batch processing when len(wavs) == 3. 
- """ - if len(wavs.shape) == 1: - out_wavs = wavs.unsqueeze(0).unsqueeze(0) - elif len(wavs.shape) == 2: - out_wavs = wavs.unsqueeze(0) - elif len(wavs.shape) == 3: - out_wavs = wavs - else: - raise ValueError("Dimension of input tensor needs to be <= 3.") - bs = out_wavs.shape[0] - batch = { - "wav": out_wavs, - "domains": torch.zeros(bs, dtype=torch.int).to(self.device), - "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288, - } - with torch.no_grad(): - output = self.model(batch) - - return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3 - - def score_dir(self, dir, dtype="float32"): - def _load_speech_task(fname, sample_rate): - - wav_data, sr = sf.read(fname, dtype=dtype) - if sr != sample_rate: - wav_data = librosa.resample( - wav_data, orig_sr=sr, target_sr=self.sample_rate - ) - wav_data = torch.from_numpy(wav_data) - - return wav_data - - score_lst = [] - for fname in tqdm(os.listdir(dir)): - speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate) - speech = speech.to(self.device) - with torch.no_grad(): - score = self.score(speech) - score_lst.append(score.item()) - return np.mean(score_lst) - - -def load_ssl_model(ckpt_path="wav2vec_small.pt"): - SSL_OUT_DIM = 768 - model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( - [ckpt_path] - ) - ssl_model = model[0] - ssl_model.remove_pretraining_modules() - return SSL_model(ssl_model, SSL_OUT_DIM) - - -class BaselineLightningModule(pl.LightningModule): - def __init__(self, ssl_model_path): - super().__init__() - self.construct_model(ssl_model_path) - self.save_hyperparameters() - - def construct_model(self, ssl_model_path): - self.feature_extractors = nn.ModuleList( - [ - load_ssl_model(ckpt_path=ssl_model_path), - DomainEmbedding(3, 128), - ] - ) - output_dim = sum( - [ - feature_extractor.get_output_dim() - for feature_extractor in self.feature_extractors - ] - ) - output_layers = [ - LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim) - ] - output_dim = output_layers[-1].get_output_dim() - output_layers.append( - Projection( - hidden_dim=2048, - activation=torch.nn.ReLU(), - range_clipping=False, - input_dim=output_dim, - ) - ) - - self.output_layers = nn.ModuleList(output_layers) - - def forward(self, inputs): - outputs = {} - for feature_extractor in self.feature_extractors: - outputs.update(feature_extractor(inputs)) - x = outputs - for output_layer in self.output_layers: - x = output_layer(x, inputs) - return x - - -class SSL_model(nn.Module): - def __init__(self, ssl_model, ssl_out_dim) -> None: - super(SSL_model, self).__init__() - self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim - - def forward(self, batch): - wav = batch["wav"] - wav = wav.squeeze(1) # [batches, wav_len] - res = self.ssl_model(wav, mask=False, features_only=True) - x = res["x"] - return {"ssl-feature": x} - - def get_output_dim(self): - return self.ssl_out_dim - - -class DomainEmbedding(nn.Module): - def __init__(self, n_domains, domain_dim) -> None: - super().__init__() - self.embedding = nn.Embedding(n_domains, domain_dim) - self.output_dim = domain_dim - - def forward(self, batch): - return {"domain-feature": self.embedding(batch["domains"])} - - def get_output_dim(self): - return self.output_dim - - -class LDConditioner(nn.Module): - """ - Conditions ssl output by listener embedding - """ - - def __init__(self, input_dim, judge_dim, num_judges=None): - super().__init__() - self.input_dim = input_dim - self.judge_dim = judge_dim - self.num_judges = num_judges - assert 
num_judges != None - self.judge_embedding = nn.Embedding(num_judges, self.judge_dim) - # concat [self.output_layer, phoneme features] - - self.decoder_rnn = nn.LSTM( - input_size=self.input_dim + self.judge_dim, - hidden_size=512, - num_layers=1, - batch_first=True, - bidirectional=True, - ) # linear? - self.out_dim = self.decoder_rnn.hidden_size * 2 - - def get_output_dim(self): - return self.out_dim - - def forward(self, x, batch): - judge_ids = batch["judge_id"] - if "phoneme-feature" in x.keys(): - concatenated_feature = torch.cat( - ( - x["ssl-feature"], - x["phoneme-feature"] - .unsqueeze(1) - .expand(-1, x["ssl-feature"].size(1), -1), - ), - dim=2, - ) - else: - concatenated_feature = x["ssl-feature"] - if "domain-feature" in x.keys(): - concatenated_feature = torch.cat( - ( - concatenated_feature, - x["domain-feature"] - .unsqueeze(1) - .expand(-1, concatenated_feature.size(1), -1), - ), - dim=2, - ) - if judge_ids != None: - concatenated_feature = torch.cat( - ( - concatenated_feature, - self.judge_embedding(judge_ids) - .unsqueeze(1) - .expand(-1, concatenated_feature.size(1), -1), - ), - dim=2, - ) - decoder_output, (h, c) = self.decoder_rnn(concatenated_feature) - return decoder_output - - -class Projection(nn.Module): - def __init__(self, input_dim, hidden_dim, activation, range_clipping=False): - super(Projection, self).__init__() - self.range_clipping = range_clipping - output_dim = 1 - if range_clipping: - self.proj = nn.Tanh() - - self.net = nn.Sequential( - nn.Linear(input_dim, hidden_dim), - activation, - nn.Dropout(0.3), - nn.Linear(hidden_dim, output_dim), - ) - self.output_dim = output_dim - - def forward(self, x, batch): - output = self.net(x) - - # range clipping - if self.range_clipping: - return self.proj(output) * 2.0 + 3 - else: - return output - - def get_output_dim(self): - return self.output_dim - - -if __name__ == "__main__": - parser = get_parser() - args = parser.parse_args() - UTMOS = UTMOSScore( - utmos_model_path=args.utmos_model_path, ssl_model_path=args.ssl_model_path - ) - score = UTMOS.score_dir(args.wav_path) - logging.info(f"UTMOS score: {score:.2f}") diff --git a/egs/zipvoice/local/evaluate_wer_hubert.py b/egs/zipvoice/local/evaluate_wer_hubert.py deleted file mode 100644 index d30346e67..000000000 --- a/egs/zipvoice/local/evaluate_wer_hubert.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -Calculate WER with Hubert models. 
-""" -import argparse -import os -import re -from pathlib import Path - -import librosa -import numpy as np -import soundfile as sf -import torch -from jiwer import compute_measures -from tqdm import tqdm -from transformers import pipeline - - -def get_parser(): - parser = argparse.ArgumentParser() - - parser.add_argument("--wav-path", type=str, help="path of the speech directory") - parser.add_argument( - "--decode-path", - type=str, - default=None, - help="path of the output file of WER information", - ) - parser.add_argument( - "--model-path", - type=str, - default=None, - help="path of the local hubert model, e.g., model/huggingface/hubert-large-ls960-ft", - ) - parser.add_argument( - "--test-list", - type=str, - default="test.tsv", - help="path of the transcript tsv file, where the first column " - "is the wav name and the last column is the transcript", - ) - parser.add_argument( - "--batch-size", type=int, default=16, help="decoding batch size" - ) - return parser - - -def post_process(text: str): - text = text.replace("‘", "'") - text = text.replace("’", "'") - text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower()) - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def process_one(hypo, truth): - truth = post_process(truth) - hypo = post_process(hypo) - - measures = compute_measures(truth, hypo) - word_num = len(truth.split(" ")) - wer = measures["wer"] - subs = measures["substitutions"] - dele = measures["deletions"] - inse = measures["insertions"] - return (truth, hypo, wer, subs, dele, inse, word_num) - - -class SpeechEvalDataset(torch.utils.data.Dataset): - def __init__(self, wav_path: str, test_list: str): - super().__init__() - self.wav_name = [] - self.wav_paths = [] - self.transcripts = [] - with Path(test_list).open("r", encoding="utf8") as f: - meta = [item.split("\t") for item in f.read().rstrip().split("\n")] - for item in meta: - self.wav_name.append(item[0]) - self.wav_paths.append(Path(wav_path, item[0] + ".wav")) - self.transcripts.append(item[-1]) - - def __len__(self): - return len(self.wav_paths) - - def __getitem__(self, index: int): - wav, sampling_rate = sf.read(self.wav_paths[index]) - item = { - "array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000), - "sampling_rate": 16000, - "reference": self.transcripts[index], - "wav_name": self.wav_name[index], - } - return item - - -def main(test_list, wav_path, model_path, decode_path, batch_size, device): - - if model_path is not None: - pipe = pipeline( - "automatic-speech-recognition", - model=model_path, - device=device, - tokenizer=model_path, - ) - else: - pipe = pipeline( - "automatic-speech-recognition", - model="facebook/hubert-large-ls960-ft", - device=device, - ) - - dataset = SpeechEvalDataset(wav_path, test_list) - - bar = tqdm( - pipe( - dataset, - generate_kwargs={"language": "english", "task": "transcribe"}, - batch_size=batch_size, - ), - total=len(dataset), - ) - - wers = [] - inses = [] - deles = [] - subses = [] - word_nums = 0 - if decode_path: - decode_dir = os.path.dirname(decode_path) - if not os.path.exists(decode_dir): - os.makedirs(decode_dir) - fout = open(decode_path, "w") - for out in bar: - wav_name = out["wav_name"][0] - transcription = post_process(out["text"].strip()) - text_ref = post_process(out["reference"][0].strip()) - truth, hypo, wer, subs, dele, inse, word_num = process_one( - transcription, text_ref - ) - if decode_path: - fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n") - wers.append(float(wer)) - 
inses.append(float(inse)) - deles.append(float(dele)) - subses.append(float(subs)) - word_nums += word_num - - wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3) - subs = round(np.mean(subses) * 100, 3) - dele = round(np.mean(deles) * 100, 3) - inse = round(np.mean(inses) * 100, 3) - print(f"WER: {wer}%\n") - if decode_path: - fout.write(f"WER: {wer}%\n") - fout.flush() - - -if __name__ == "__main__": - parser = get_parser() - args = parser.parse_args() - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - else: - device = torch.device("cpu") - main( - args.test_list, - args.wav_path, - args.model_path, - args.decode_path, - args.batch_size, - device, - ) diff --git a/egs/zipvoice/local/evaluate_wer_seedtts.py b/egs/zipvoice/local/evaluate_wer_seedtts.py deleted file mode 100644 index f7e256387..000000000 --- a/egs/zipvoice/local/evaluate_wer_seedtts.py +++ /dev/null @@ -1,181 +0,0 @@ -""" -Calculate WER with Whisper-large-v3 or Paraformer models, -following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval -""" - -import argparse -import os -import string - -import numpy as np -import scipy -import soundfile as sf -import torch -import zhconv -from funasr import AutoModel -from jiwer import compute_measures -from tqdm import tqdm -from transformers import WhisperForConditionalGeneration, WhisperProcessor -from zhon.hanzi import punctuation - - -def get_parser(): - parser = argparse.ArgumentParser() - - parser.add_argument("--wav-path", type=str, help="path of the speech directory") - parser.add_argument( - "--decode-path", - type=str, - default=None, - help="path of the output file of WER information", - ) - parser.add_argument( - "--model-path", - type=str, - default=None, - help="path of the local whisper and paraformer model, " - "e.g., whisper: model/huggingface/whisper-large-v3/, " - "paraformer: model/huggingface/paraformer-zh/", - ) - parser.add_argument( - "--test-list", - type=str, - default="test.tsv", - help="path of the transcript tsv file, where the first column " - "is the wav name and the last column is the transcript", - ) - parser.add_argument("--lang", type=str, help="decoded language, zh or en") - return parser - - -def load_en_model(model_path): - if model_path is None: - model_path = "openai/whisper-large-v3" - processor = WhisperProcessor.from_pretrained(model_path) - model = WhisperForConditionalGeneration.from_pretrained(model_path) - return processor, model - - -def load_zh_model(model_path): - if model_path is None: - model_path = "paraformer-zh" - model = AutoModel(model=model_path) - return model - - -def process_one(hypo, truth, lang): - punctuation_all = punctuation + string.punctuation - for x in punctuation_all: - if x == "'": - continue - truth = truth.replace(x, "") - hypo = hypo.replace(x, "") - - truth = truth.replace(" ", " ") - hypo = hypo.replace(" ", " ") - - if lang == "zh": - truth = " ".join([x for x in truth]) - hypo = " ".join([x for x in hypo]) - elif lang == "en": - truth = truth.lower() - hypo = hypo.lower() - else: - raise NotImplementedError - - measures = compute_measures(truth, hypo) - word_num = len(truth.split(" ")) - wer = measures["wer"] - subs = measures["substitutions"] - dele = measures["deletions"] - inse = measures["insertions"] - return (truth, hypo, wer, subs, dele, inse, word_num) - - -def main(test_list, wav_path, model_path, decode_path, lang, device): - if lang == "en": - processor, model = load_en_model(model_path) - model.to(device) - elif lang == "zh": - model = 
load_zh_model(model_path) - params = [] - for line in open(test_list).readlines(): - line = line.strip() - items = line.split("\t") - wav_name, text_ref = items[0], items[-1] - file_path = os.path.join(wav_path, wav_name + ".wav") - assert os.path.exists(file_path), f"{file_path}" - - params.append((file_path, text_ref)) - wers = [] - inses = [] - deles = [] - subses = [] - word_nums = 0 - if decode_path: - decode_dir = os.path.dirname(decode_path) - if not os.path.exists(decode_dir): - os.makedirs(decode_dir) - fout = open(decode_path, "w") - for wav_path, text_ref in tqdm(params): - if lang == "en": - wav, sr = sf.read(wav_path) - if sr != 16000: - wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr)) - input_features = processor( - wav, sampling_rate=16000, return_tensors="pt" - ).input_features - input_features = input_features.to(device) - forced_decoder_ids = processor.get_decoder_prompt_ids( - language="english", task="transcribe" - ) - predicted_ids = model.generate( - input_features, forced_decoder_ids=forced_decoder_ids - ) - transcription = processor.batch_decode( - predicted_ids, skip_special_tokens=True - )[0] - elif lang == "zh": - res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True) - transcription = res[0]["text"] - transcription = zhconv.convert(transcription, "zh-cn") - - truth, hypo, wer, subs, dele, inse, word_num = process_one( - transcription, text_ref, lang - ) - if decode_path: - fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n") - wers.append(float(wer)) - inses.append(float(inse)) - deles.append(float(dele)) - subses.append(float(subs)) - word_nums += word_num - - wer_avg = round(np.mean(wers) * 100, 3) - wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3) - subs = round(np.mean(subses) * 100, 3) - dele = round(np.mean(deles) * 100, 3) - inse = round(np.mean(inses) * 100, 3) - print(f"Seed-TTS WER: {wer_avg}%\n") - print(f"WER: {wer}%\n") - if decode_path: - fout.write(f"SeedTTS WER: {wer_avg}%\n") - fout.write(f"WER: {wer}%\n") - fout.flush() - - -if __name__ == "__main__": - parser = get_parser() - args = parser.parse_args() - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - else: - device = torch.device("cpu") - main( - args.test_list, - args.wav_path, - args.model_path, - args.decode_path, - args.lang, - device, - ) diff --git a/egs/zipvoice/local/feature.py b/egs/zipvoice/local/feature.py deleted file mode 120000 index 08ef7d228..000000000 --- a/egs/zipvoice/local/feature.py +++ /dev/null @@ -1 +0,0 @@ -../zipvoice/feature.py \ No newline at end of file diff --git a/egs/zipvoice/local/pinyin.txt b/egs/zipvoice/local/pinyin.txt deleted file mode 100644 index cd8d14dc3..000000000 --- a/egs/zipvoice/local/pinyin.txt +++ /dev/null @@ -1,1550 +0,0 @@ -a -a1 -a2 -a3 -a4 -ai1 -ai2 -ai3 -ai4 -an1 -an2 -an3 -an4 -ang1 -ang2 -ang3 -ang4 -ao1 -ao2 -ao3 -ao4 -ba -ba1 -ba2 -ba3 -ba4 -bai -bai1 -bai2 -bai3 -bai4 -ban -ban1 -ban3 -ban4 -bang1 -bang3 -bang4 -bao1 -bao2 -bao3 -bao4 -bei -bei1 -bei3 -bei4 -ben1 -ben3 -ben4 -beng -beng1 -beng2 -beng3 -beng4 -bi1 -bi2 -bi3 -bi4 -bian -bian1 -bian3 -bian4 -biang2 -biao1 -biao3 -biao4 -bie1 -bie2 -bie3 -bie4 -bin -bin1 -bin3 -bin4 -bing1 -bing3 -bing4 -bo -bo1 -bo2 -bo3 -bo4 -bu1 -bu2 -bu3 -bu4 -ca1 -ca3 -ca4 -cai1 -cai2 -cai3 -cai4 -can1 -can2 -can3 -can4 -cang1 -cang2 -cang3 -cang4 -cao1 -cao2 -cao3 -cao4 -ce4 -cei4 -cen1 -cen2 -ceng1 -ceng2 -ceng4 -cha1 -cha2 -cha3 -cha4 -chai1 -chai2 -chai3 -chai4 -chan1 -chan2 -chan3 -chan4 
-chang -chang1 -chang2 -chang3 -chang4 -chao1 -chao2 -chao3 -chao4 -che1 -che2 -che3 -che4 -chen -chen1 -chen2 -chen3 -chen4 -cheng1 -cheng2 -cheng3 -cheng4 -chi -chi1 -chi2 -chi3 -chi4 -chong1 -chong2 -chong3 -chong4 -chou1 -chou2 -chou3 -chou4 -chu -chu1 -chu2 -chu3 -chu4 -chua1 -chua3 -chua4 -chuai1 -chuai2 -chuai3 -chuai4 -chuan1 -chuan2 -chuan3 -chuan4 -chuang1 -chuang2 -chuang3 -chuang4 -chui1 -chui2 -chui3 -chui4 -chun1 -chun2 -chun3 -chuo1 -chuo4 -ci1 -ci2 -ci3 -ci4 -cong1 -cong2 -cong3 -cong4 -cou1 -cou2 -cou3 -cou4 -cu1 -cu2 -cu3 -cu4 -cuan1 -cuan2 -cuan4 -cui -cui1 -cui3 -cui4 -cun1 -cun2 -cun3 -cun4 -cuo1 -cuo2 -cuo3 -cuo4 -da -da1 -da2 -da3 -da4 -dai -dai1 -dai3 -dai4 -dan1 -dan3 -dan4 -dang -dang1 -dang3 -dang4 -dao1 -dao2 -dao3 -dao4 -de -de1 -de2 -dei1 -dei3 -den4 -deng1 -deng3 -deng4 -di1 -di2 -di3 -di4 -dia3 -dian1 -dian2 -dian3 -dian4 -diao1 -diao3 -diao4 -die1 -die2 -die3 -die4 -din4 -ding1 -ding3 -ding4 -diu1 -dong1 -dong3 -dong4 -dou1 -dou3 -dou4 -du1 -du2 -du3 -du4 -duan1 -duan3 -duan4 -dui1 -dui3 -dui4 -dun1 -dun3 -dun4 -duo -duo1 -duo2 -duo3 -duo4 -e -e1 -e2 -e3 -e4 -ei1 -ei2 -ei3 -ei4 -en1 -en3 -en4 -eng1 -er -er2 -er3 -er4 -fa -fa1 -fa2 -fa3 -fa4 -fan1 -fan2 -fan3 -fan4 -fang -fang1 -fang2 -fang3 -fang4 -fei1 -fei2 -fei3 -fei4 -fen1 -fen2 -fen3 -fen4 -feng1 -feng2 -feng3 -feng4 -fiao4 -fo2 -fou1 -fou2 -fou3 -fu -fu1 -fu2 -fu3 -fu4 -ga1 -ga2 -ga3 -ga4 -gai1 -gai3 -gai4 -gan1 -gan3 -gan4 -gang1 -gang3 -gang4 -gao1 -gao3 -gao4 -ge1 -ge2 -ge3 -ge4 -gei3 -gen1 -gen2 -gen3 -gen4 -geng1 -geng3 -geng4 -gong -gong1 -gong3 -gong4 -gou1 -gou3 -gou4 -gu -gu1 -gu2 -gu3 -gu4 -gua1 -gua2 -gua3 -gua4 -guai1 -guai3 -guai4 -guan1 -guan3 -guan4 -guang -guang1 -guang3 -guang4 -gui1 -gui3 -gui4 -gun3 -gun4 -guo -guo1 -guo2 -guo3 -guo4 -ha1 -ha2 -ha3 -ha4 -hai -hai1 -hai2 -hai3 -hai4 -han -han1 -han2 -han3 -han4 -hang1 -hang2 -hang3 -hang4 -hao1 -hao2 -hao3 -hao4 -he1 -he2 -he3 -he4 -hei1 -hen1 -hen2 -hen3 -hen4 -heng1 -heng2 -heng4 -hm -hng -hong1 -hong2 -hong3 -hong4 -hou1 -hou2 -hou3 -hou4 -hu -hu1 -hu2 -hu3 -hu4 -hua1 -hua2 -hua4 -huai -huai2 -huai4 -huan1 -huan2 -huan3 -huan4 -huang -huang1 -huang2 -huang3 -huang4 -hui -hui1 -hui2 -hui3 -hui4 -hun1 -hun2 -hun3 -hun4 -huo -huo1 -huo2 -huo3 -huo4 -ji1 -ji2 -ji3 -ji4 -jia -jia1 -jia2 -jia3 -jia4 -jian -jian1 -jian3 -jian4 -jiang -jiang1 -jiang3 -jiang4 -jiao -jiao1 -jiao2 -jiao3 -jiao4 -jie -jie1 -jie2 -jie3 -jie4 -jin1 -jin3 -jin4 -jing -jing1 -jing3 -jing4 -jiong1 -jiong3 -jiong4 -jiu -jiu1 -jiu2 -jiu3 -jiu4 -ju -ju1 -ju2 -ju3 -ju4 -juan1 -juan3 -juan4 -jue1 -jue2 -jue3 -jue4 -jun1 -jun3 -jun4 -ka1 -ka3 -kai1 -kai3 -kai4 -kan1 -kan3 -kan4 -kang1 -kang2 -kang3 -kang4 -kao1 -kao3 -kao4 -ke -ke1 -ke2 -ke3 -ke4 -kei1 -ken1 -ken3 -ken4 -keng1 -keng3 -kong1 -kong3 -kong4 -kou1 -kou3 -kou4 -ku1 -ku2 -ku3 -ku4 -kua1 -kua3 -kua4 -kuai3 -kuai4 -kuan1 -kuan3 -kuang1 -kuang2 -kuang3 -kuang4 -kui1 -kui2 -kui3 -kui4 -kun -kun1 -kun3 -kun4 -kuo4 -la -la1 -la2 -la3 -la4 -lai2 -lai3 -lai4 -lan2 -lan3 -lan4 -lang -lang1 -lang2 -lang3 -lang4 -lao -lao1 -lao2 -lao3 -lao4 -le -le1 -le4 -lei -lei1 -lei2 -lei3 -lei4 -len4 -leng1 -leng2 -leng3 -leng4 -li -li1 -li2 -li3 -li4 -lia3 -lian2 -lian3 -lian4 -liang -liang2 -liang3 -liang4 -liao1 -liao2 -liao3 -liao4 -lie -lie1 -lie2 -lie3 -lie4 -lin1 -lin2 -lin3 -lin4 -ling -ling1 -ling2 -ling3 -ling4 -liu1 -liu2 -liu3 -liu4 -lo -long1 -long2 -long3 -long4 -lou -lou1 -lou2 -lou3 -lou4 -lu -lu1 -lu2 -lu3 -lu4 -luan2 -luan3 -luan4 -lun1 -lun2 -lun3 -lun4 -luo -luo1 -luo2 -luo3 -luo4 -lv2 -lv3 -lv4 -lve3 -lve4 -m1 
-m2 -m4 -ma -ma1 -ma2 -ma3 -ma4 -mai2 -mai3 -mai4 -man1 -man2 -man3 -man4 -mang1 -mang2 -mang3 -mang4 -mao1 -mao2 -mao3 -mao4 -me -me1 -mei2 -mei3 -mei4 -men -men1 -men2 -men4 -meng -meng1 -meng2 -meng3 -meng4 -mi1 -mi2 -mi3 -mi4 -mian2 -mian3 -mian4 -miao1 -miao2 -miao3 -miao4 -mie -mie1 -mie2 -mie4 -min -min2 -min3 -ming -ming2 -ming3 -ming4 -miu3 -miu4 -mo -mo1 -mo2 -mo3 -mo4 -mou1 -mou2 -mou3 -mou4 -mu2 -mu3 -mu4 -n -n2 -n3 -n4 -na -na1 -na2 -na3 -na4 -nai2 -nai3 -nai4 -nan1 -nan2 -nan3 -nan4 -nang -nang1 -nang2 -nang3 -nang4 -nao1 -nao2 -nao3 -nao4 -ne -ne2 -ne4 -nei2 -nei3 -nei4 -nen4 -neng2 -neng3 -neng4 -ng -ng2 -ng3 -ng4 -ni1 -ni2 -ni3 -ni4 -nia1 -nian1 -nian2 -nian3 -nian4 -niang2 -niang3 -niang4 -niao3 -niao4 -nie1 -nie2 -nie3 -nie4 -nin -nin2 -nin3 -ning2 -ning3 -ning4 -niu1 -niu2 -niu3 -niu4 -nong2 -nong3 -nong4 -nou2 -nou3 -nou4 -nu2 -nu3 -nu4 -nuan2 -nuan3 -nuan4 -nun2 -nun4 -nuo2 -nuo3 -nuo4 -nv2 -nv3 -nv4 -nve4 -o -o1 -o2 -o3 -o4 -ou -ou1 -ou2 -ou3 -ou4 -pa1 -pa2 -pa3 -pa4 -pai1 -pai2 -pai3 -pai4 -pan1 -pan2 -pan3 -pan4 -pang1 -pang2 -pang3 -pang4 -pao1 -pao2 -pao3 -pao4 -pei1 -pei2 -pei3 -pei4 -pen1 -pen2 -pen3 -pen4 -peng1 -peng2 -peng3 -peng4 -pi1 -pi2 -pi3 -pi4 -pian1 -pian2 -pian3 -pian4 -piao1 -piao2 -piao3 -piao4 -pie1 -pie3 -pie4 -pin1 -pin2 -pin3 -pin4 -ping1 -ping2 -ping3 -ping4 -po -po1 -po2 -po3 -po4 -pou1 -pou2 -pou3 -pou4 -pu -pu1 -pu2 -pu3 -pu4 -qi -qi1 -qi2 -qi3 -qi4 -qia1 -qia2 -qia3 -qia4 -qian -qian1 -qian2 -qian3 -qian4 -qiang1 -qiang2 -qiang3 -qiang4 -qiao1 -qiao2 -qiao3 -qiao4 -qie1 -qie2 -qie3 -qie4 -qin1 -qin2 -qin3 -qin4 -qing -qing1 -qing2 -qing3 -qing4 -qiong1 -qiong2 -qiong4 -qiu1 -qiu2 -qiu3 -qiu4 -qu -qu1 -qu2 -qu3 -qu4 -quan -quan1 -quan2 -quan3 -quan4 -que1 -que2 -que4 -qun1 -qun2 -qun3 -ran2 -ran3 -ran4 -rang1 -rang2 -rang3 -rang4 -rao2 -rao3 -rao4 -re2 -re3 -re4 -ren2 -ren3 -ren4 -reng1 -reng2 -reng4 -ri4 -rong -rong1 -rong2 -rong3 -rong4 -rou2 -rou3 -rou4 -ru -ru2 -ru3 -ru4 -rua2 -ruan2 -ruan3 -ruan4 -rui2 -rui3 -rui4 -run2 -run3 -run4 -ruo2 -ruo4 -sa -sa1 -sa3 -sa4 -sai1 -sai3 -sai4 -san -san1 -san3 -san4 -sang1 -sang3 -sang4 -sao1 -sao3 -sao4 -se1 -se4 -sen1 -sen3 -seng1 -seng4 -sha -sha1 -sha2 -sha3 -sha4 -shai1 -shai3 -shai4 -shan1 -shan2 -shan3 -shan4 -shang -shang1 -shang3 -shang4 -shao1 -shao2 -shao3 -shao4 -she1 -she2 -she3 -she4 -shei2 -shen1 -shen2 -shen3 -shen4 -sheng1 -sheng2 -sheng3 -sheng4 -shi -shi1 -shi2 -shi3 -shi4 -shou -shou1 -shou2 -shou3 -shou4 -shu1 -shu2 -shu3 -shu4 -shua1 -shua3 -shua4 -shuai1 -shuai3 -shuai4 -shuan1 -shuan4 -shuang1 -shuang3 -shuang4 -shui -shui2 -shui3 -shui4 -shun3 -shun4 -shuo1 -shuo2 -shuo4 -si -si1 -si2 -si3 -si4 -song1 -song2 -song3 -song4 -sou1 -sou3 -sou4 -su1 -su2 -su3 -su4 -suan1 -suan3 -suan4 -sui1 -sui2 -sui3 -sui4 -sun1 -sun3 -sun4 -suo -suo1 -suo2 -suo3 -suo4 -ta -ta1 -ta2 -ta3 -ta4 -tai -tai1 -tai2 -tai3 -tai4 -tan1 -tan2 -tan3 -tan4 -tang1 -tang2 -tang3 -tang4 -tao1 -tao2 -tao3 -tao4 -te -te4 -tei1 -teng1 -teng2 -teng4 -ti -ti1 -ti2 -ti3 -ti4 -tian1 -tian2 -tian3 -tian4 -tiao -tiao1 -tiao2 -tiao3 -tiao4 -tie1 -tie2 -tie3 -tie4 -ting1 -ting2 -ting3 -ting4 -tong1 -tong2 -tong3 -tong4 -tou -tou1 -tou2 -tou3 -tou4 -tu -tu1 -tu2 -tu3 -tu4 -tuan1 -tuan2 -tuan3 -tuan4 -tui1 -tui2 -tui3 -tui4 -tun1 -tun2 -tun3 -tun4 -tuo1 -tuo2 -tuo3 -tuo4 -wa -wa1 -wa2 -wa3 -wa4 -wai -wai1 -wai3 -wai4 -wan1 -wan2 -wan3 -wan4 -wang1 -wang2 -wang3 -wang4 -wei -wei1 -wei2 -wei3 -wei4 -wen -wen1 -wen2 -wen3 -wen4 -weng1 -weng3 -weng4 -wo1 -wo3 -wo4 -wong4 -wu -wu1 -wu2 -wu3 -wu4 -xi1 -xi2 -xi3 -xi4 -xia1 
-xia2 -xia3 -xia4 -xian -xian1 -xian2 -xian3 -xian4 -xiang1 -xiang2 -xiang3 -xiang4 -xiao -xiao1 -xiao2 -xiao3 -xiao4 -xie1 -xie2 -xie3 -xie4 -xin -xin1 -xin2 -xin3 -xin4 -xing -xing1 -xing2 -xing3 -xing4 -xiong1 -xiong2 -xiong3 -xiong4 -xiu1 -xiu2 -xiu3 -xiu4 -xu -xu1 -xu2 -xu3 -xu4 -xuan1 -xuan2 -xuan3 -xuan4 -xue1 -xue2 -xue3 -xue4 -xun1 -xun2 -xun4 -ya -ya1 -ya2 -ya3 -ya4 -yan1 -yan2 -yan3 -yan4 -yang -yang1 -yang2 -yang3 -yang4 -yao1 -yao2 -yao3 -yao4 -ye -ye1 -ye2 -ye3 -ye4 -yi -yi1 -yi2 -yi3 -yi4 -yin -yin1 -yin2 -yin3 -yin4 -ying1 -ying2 -ying3 -ying4 -yo -yo1 -yong1 -yong2 -yong3 -yong4 -you -you1 -you2 -you3 -you4 -yu -yu1 -yu2 -yu3 -yu4 -yuan1 -yuan2 -yuan3 -yuan4 -yue1 -yue2 -yue3 -yue4 -yun -yun1 -yun2 -yun3 -yun4 -za1 -za2 -za3 -za4 -zai1 -zai3 -zai4 -zan -zan1 -zan2 -zan3 -zan4 -zang1 -zang3 -zang4 -zao1 -zao2 -zao3 -zao4 -ze -ze2 -ze4 -zei2 -zen -zen1 -zen3 -zen4 -zeng1 -zeng3 -zeng4 -zha -zha1 -zha2 -zha3 -zha4 -zhai1 -zhai2 -zhai3 -zhai4 -zhan1 -zhan2 -zhan3 -zhan4 -zhang -zhang1 -zhang3 -zhang4 -zhao -zhao1 -zhao2 -zhao3 -zhao4 -zhe -zhe1 -zhe2 -zhe3 -zhe4 -zhei4 -zhen1 -zhen2 -zhen3 -zhen4 -zheng1 -zheng3 -zheng4 -zhi -zhi1 -zhi2 -zhi3 -zhi4 -zhong1 -zhong3 -zhong4 -zhou1 -zhou2 -zhou3 -zhou4 -zhu1 -zhu2 -zhu3 -zhu4 -zhua1 -zhua3 -zhuai1 -zhuai3 -zhuai4 -zhuan1 -zhuan2 -zhuan3 -zhuan4 -zhuang1 -zhuang3 -zhuang4 -zhui1 -zhui3 -zhui4 -zhun1 -zhun3 -zhun4 -zhuo -zhuo1 -zhuo2 -zhuo4 -zi -zi1 -zi2 -zi3 -zi4 -zong -zong1 -zong3 -zong4 -zou1 -zou3 -zou4 -zu1 -zu2 -zu3 -zu4 -zuan1 -zuan3 -zuan4 -zui -zui1 -zui2 -zui3 -zui4 -zun1 -zun2 -zun3 -zun4 -zuo -zuo1 -zuo2 -zuo3 -zuo4 -ê1 -ê2 -ê3 -ê4 diff --git a/egs/zipvoice/local/prepare_token_file_emilia.py b/egs/zipvoice/local/prepare_token_file_emilia.py deleted file mode 100644 index 68af8d397..000000000 --- a/egs/zipvoice/local/prepare_token_file_emilia.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file generates the file that maps tokens to IDs. 
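For concreteness, each entry of the pinyin list above is turned into at most two tokens by the script below: the initial (suffixed with `0` so it cannot collide with an espeak token) and a tone3-style final. A minimal sketch of that mapping, assuming `pypinyin` is installed and using the entry `zhang4` as the example:

```python
# Sketch of the pinyin-to-token mapping used when building data/tokens_emilia.txt.
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials

syllable = "zhang4"  # one entry from the pinyin list above
initial = to_initials(syllable, strict=False)  # "zh"
final = to_finals_tone3(syllable, strict=False, neutral_tone_with_five=True)  # "ang4"

tokens = []
if initial != "":
    tokens.append(initial + "0")  # '0' suffix keeps initials disjoint from espeak tokens
if final != "":
    tokens.append(final)
print(tokens)  # expected: ['zh0', 'ang4']
```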
-""" - -import argparse -import logging -from pathlib import Path -from typing import List - -from piper_phonemize import get_espeak_map -from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--tokens", - type=Path, - default=Path("data/tokens_emilia.txt"), - help="Path to the dict that maps the text tokens to IDs", - ) - - parser.add_argument( - "--pinyin", - type=Path, - default=Path("local/pinyin.txt"), - help="Path to the all unique pinyin", - ) - - return parser.parse_args() - - -def get_pinyin_tokens(pinyin: Path) -> List[str]: - phones = set() - with open(pinyin, "r") as f: - for line in f: - x = line.strip() - initial = to_initials(x, strict=False) - # don't want to share tokens with espeak tokens, so use tone3 style - finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True) - if initial != "": - # don't want to share tokens with espeak tokens, so add a '0' after each initial - phones.add(initial + "0") - if finals != "": - phones.add(finals) - return sorted(phones) - - -def get_token2id(args): - """Get a dict that maps token to IDs, and save it to the given filename.""" - all_tokens = get_espeak_map() # token: [token_id] - all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()} - # sort by token_id - all_tokens = sorted(all_tokens.items(), key=lambda x: x[1]) - - all_pinyin = get_pinyin_tokens(args.pinyin) - with open(args.tokens, "w", encoding="utf-8") as f: - for token, token_id in all_tokens: - f.write(f"{token} {token_id}\n") - num_espeak_tokens = len(all_tokens) - for i, pinyin in enumerate(all_pinyin): - f.write(f"{pinyin} {num_espeak_tokens + i}\n") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - get_token2id(args) diff --git a/egs/zipvoice/local/prepare_token_file_libritts.py b/egs/zipvoice/local/prepare_token_file_libritts.py deleted file mode 100644 index 374b02613..000000000 --- a/egs/zipvoice/local/prepare_token_file_libritts.py +++ /dev/null @@ -1,31 +0,0 @@ -import re -from collections import Counter - -from lhotse import load_manifest_lazy - - -def prepare_tokens(manifest_file, token_file): - counter = Counter() - manifest = load_manifest_lazy(manifest_file) - for cut in manifest: - line = re.sub(r"\s+", " ", cut.supervisions[0].text) - counter.update(line) - - unique_chars = set(counter.keys()) - - if "_" in unique_chars: - unique_chars.remove("_") - - sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True) - - result = ["_"] + sorted_chars - - with open(token_file, "w", encoding="utf-8") as file: - for index, char in enumerate(result): - file.write(f"{char} {index}\n") - - -if __name__ == "__main__": - manifest_file = "data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz" - output_token_file = "data/tokens_libritts.txt" - prepare_tokens(manifest_file, output_token_file) diff --git a/egs/zipvoice/local/preprocess_emilia.py b/egs/zipvoice/local/preprocess_emilia.py deleted file mode 100644 index 96a7ff228..000000000 --- a/egs/zipvoice/local/preprocess_emilia.py +++ /dev/null @@ -1,155 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. 
(authors: Zengwei Yao, -# Zengrui Jin, -# Wei Kang) -# 2024 Tsinghua University (authors: Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file reads the texts in given manifest and save the cleaned new cuts. -""" - -import argparse -import glob -import logging -import os -from concurrent.futures import ProcessPoolExecutor as Pool -from pathlib import Path -from typing import List - -from lhotse import CutSet, load_manifest_lazy -from tokenizer import ( - is_alphabet, - is_chinese, - is_hangul, - is_japanese, - tokenize_by_CJK_char, -) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--subset", - type=str, - help="Subset of emilia, (ZH, EN, etc.)", - ) - - parser.add_argument( - "--jobs", - type=int, - default=20, - help="Number of jobs to processing.", - ) - - parser.add_argument( - "--source-dir", - type=str, - default="data/manifests/splits_raw", - help="The source directory of manifest files.", - ) - - parser.add_argument( - "--dest-dir", - type=str, - default="data/manifests/splits", - help="The destination directory of manifest files.", - ) - - return parser.parse_args() - - -def preprocess_emilia(file_name: str, input_dir: Path, output_dir: Path): - logging.info(f"Processing {file_name}") - if (output_dir / file_name).is_file(): - logging.info(f"{file_name} exists, skipping.") - return - - def _filter_cut(cut): - text = cut.supervisions[0].text - duration = cut.supervisions[0].duration - chinese = [] - english = [] - - # only contains chinese and space and alphabets - clean_chars = [] - for x in text: - if is_hangul(x): - logging.warning(f"Delete cut with text containing Korean : {text}") - return False - if is_japanese(x): - logging.warning(f"Delete cut with text containing Japanese : {text}") - return False - if is_chinese(x): - chinese.append(x) - clean_chars.append(x) - if is_alphabet(x): - english.append(x) - clean_chars.append(x) - if x == " ": - clean_chars.append(x) - if len(english) + len(chinese) == 0: - logging.warning(f"Delete cut with text has no valid chars : {text}") - return False - - words = tokenize_by_CJK_char("".join(clean_chars)) - for i in range(len(words) - 10): - if words[i : i + 10].count(words[i]) == 10: - logging.warning(f"Delete cut with text with too much repeats : {text}") - return False - # word speed, 20 - 600 / minute - if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60: - logging.warning( - f"Delete cut with audio text mismatch, duration : {duration}s, " - f"words : {len(words)}, text : {text}" - ) - return False - return True - - try: - cut_set = load_manifest_lazy(input_dir / file_name) - cut_set = cut_set.filter(_filter_cut) - cut_set.to_file(output_dir / file_name) - except Exception as e: - logging.error(f"Manifest {file_name} failed with error: {e}") - os.remove(str(output_dir / file_name)) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] 
%(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - - input_dir = Path(args.source_dir) - output_dir = Path(args.dest_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz") - - with Pool(max_workers=args.jobs) as pool: - futures = [ - pool.submit( - preprocess_emilia, filename.split("/")[-1], input_dir, output_dir - ) - for filename in cut_files - ] - for f in futures: - f.result() - f.done() - logging.info("Processing done.") diff --git a/egs/zipvoice/local/tokenizer.py b/egs/zipvoice/local/tokenizer.py deleted file mode 120000 index 024e340cc..000000000 --- a/egs/zipvoice/local/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../zipvoice/tokenizer.py \ No newline at end of file diff --git a/egs/zipvoice/local/validate_manifest.py b/egs/zipvoice/local/validate_manifest.py deleted file mode 100644 index 68159ae03..000000000 --- a/egs/zipvoice/local/validate_manifest.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut - -We will add more checks later if needed. 
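One of the filters applied in `preprocess_emilia.py` above is a speaking-rate check: a cut is kept only if its duration corresponds to roughly 20-600 words per minute. A minimal standalone sketch of that rule (not the recipe's code, just the same arithmetic):

```python
# Speaking-rate filter: keep a cut only when its duration is consistent with 20-600 words/minute.
def rate_ok(duration_s: float, num_words: int) -> bool:
    min_dur = num_words / 600 * 60  # shortest plausible duration (600 words per minute)
    max_dur = num_words / 20 * 60   # longest plausible duration (20 words per minute)
    return min_dur <= duration_s <= max_dur

print(rate_ok(5.0, 20))   # 240 words/min -> True, cut is kept
print(rate_ok(0.5, 20))   # 2400 words/min -> False, likely audio/text mismatch
```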
- -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/spectrogram/ljspeech_cuts_all.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest_lazy -from lhotse.dataset.speech_synthesis import validate_for_tts - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - manifest = args.manifest - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest_lazy(manifest) - assert isinstance(cut_set, CutSet), type(cut_set) - - validate_for_tts(cut_set) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/zipvoice/requirements.txt b/egs/zipvoice/requirements.txt deleted file mode 100644 index cbbe860a5..000000000 --- a/egs/zipvoice/requirements.txt +++ /dev/null @@ -1,17 +0,0 @@ ---find-links https://k2-fsa.github.io/icefall/piper_phonemize.html - -torch -torchaudio -huggingface_hub -lhotse -safetensors -vocos - -# Normalization -cn2an -inflect - -# Tokenization -jieba -piper_phonemize -pypinyin diff --git a/egs/zipvoice/scripts/evaluate.sh b/egs/zipvoice/scripts/evaluate.sh deleted file mode 100644 index fbc0eed9a..000000000 --- a/egs/zipvoice/scripts/evaluate.sh +++ /dev/null @@ -1,102 +0,0 @@ -export CUDA_VISIBLE_DEVICES="0" -export PYTHONWARNINGS=ignore -export PYTHONPATH=../../:$PYTHONPATH - -# Uncomment this if you have trouble connecting to HuggingFace -# export HF_ENDPOINT=https://hf-mirror.com - -start_stage=1 -end_stage=3 - -# Models used for SIM-o evaluation. -# SV model wavlm_large_finetune.pth is downloaded from https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification -# SSL model wavlm_large.pt is downloaded from https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt -sv_model_path=model/UniSpeech/wavlm_large_finetune.pth -wavlm_model_path=model/s3prl/wavlm_large.pt - -# Models used for UTMOS evaluation. 
-# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -P model/huggingface/utmos/utmos.pt -# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -P model/huggingface/utmos/wav2vec_small.pt -utmos_model_path=model/huggingface/utmos/utmos.pt -wav2vec_model_path=model/huggingface/utmos/wav2vec_small.pt - - -if [ $start_stage -le 1 ] && [ $end_stage -ge 1 ]; then - - echo "=====Evaluate for Seed-TTS test-en=======" - test_list=testset/test_seedtts_en.tsv - wav_path=results/zipvoice_seedtts_en - - echo $wav_path - echo "-----Computing SIM-o-----" - python3 local/evaluate_sim.py \ - --sv-model-path ${sv_model_path} \ - --ssl-model-path ${wavlm_model_path} \ - --eval-path ${wav_path} \ - --test-list ${test_list} - - echo "-----Computing WER-----" - python3 local/evaluate_wer_seedtts.py \ - --test-list ${test_list} \ - --wav-path ${wav_path} \ - --lang "en" - - echo "-----Computing UTSMOS-----" - python3 local/evaluate_utmos.py \ - --wav-path ${wav_path} \ - --utmos-model-path ${utmos_model_path} \ - --ssl-model-path ${wav2vec_model_path} - -fi - -if [ $start_stage -le 2 ] && [ $end_stage -ge 2 ]; then - echo "=====Evaluate for Seed-TTS test-zh=======" - test_list=testset/test_seedtts_zh.tsv - wav_path=results/zipvoice_seedtts_zh - - echo $wav_path - echo "-----Computing SIM-o-----" - python3 local/evaluate_sim.py \ - --sv-model-path ${sv_model_path} \ - --ssl-model-path ${wavlm_model_path} \ - --eval-path ${wav_path} \ - --test-list ${test_list} - - echo "-----Computing WER-----" - python3 local/evaluate_wer_seedtts.py \ - --test-list ${test_list} \ - --wav-path ${wav_path} \ - --lang "zh" - - echo "-----Computing UTSMOS-----" - python3 local/evaluate_utmos.py \ - --wav-path ${wav_path} \ - --utmos-model-path ${utmos_model_path} \ - --ssl-model-path ${wav2vec_model_path} -fi - -if [ $start_stage -le 3 ] && [ $end_stage -ge 3 ]; then - echo "=====Evaluate for Librispeech test-clean=======" - test_list=testset/test_librispeech_pc_test_clean.tsv - wav_path=results/zipvoice_librispeech_test_clean - - echo $wav_path - echo "-----Computing SIM-o-----" - python3 local/evaluate_sim.py \ - --sv-model-path ${sv_model_path} \ - --ssl-model-path ${wavlm_model_path} \ - --eval-path ${wav_path} \ - --test-list ${test_list} - - echo "-----Computing WER-----" - python3 local/evaluate_wer_hubert.py \ - --test-list ${test_list} \ - --wav-path ${wav_path} \ - - echo "-----Computing UTSMOS-----" - python3 local/evaluate_utmos.py \ - --wav-path ${wav_path} \ - --utmos-model-path ${utmos_model_path} \ - --ssl-model-path ${wav2vec_model_path} - -fi \ No newline at end of file diff --git a/egs/zipvoice/scripts/prepare_emilia.sh b/egs/zipvoice/scripts/prepare_emilia.sh deleted file mode 100755 index bf19ed1a5..000000000 --- a/egs/zipvoice/scripts/prepare_emilia.sh +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=0 -stop_stage=5 -sampling_rate=24000 -nj=32 - -dl_dir=$PWD/download - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. 
-mkdir -p data
-
-log() {
-  # This function is from espnet
-  local fname=${BASH_SOURCE[1]##*/}
-  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-log "dl_dir: $dl_dir"
-
-if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "Stage 0: Download data"
-
-  # Your download directory should look like this:
-  #
-  # download/Amphion___Emilia
-  # ├── metafile.yaml
-  # ├── raw
-  # │   ├── DE
-  # │   ├── EN
-  # │   ├── FR
-  # │   ├── JA
-  # │   ├── KO
-  # │   ├── openemilia_45batches.tar.gz
-  # │   ├── openemilia_all.tar.gz
-  # │   └── ZH
-  # └── README.md
-
-  if [ ! -d $dl_dir/Amphion___Emilia/raw ]; then
-    log "Please refer to https://openxlab.org.cn/datasets/Amphion/Emilia to download the dataset."
-    exit 1
-  fi
-
-fi
-
-if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare emilia manifests (EN and ZH only)"
-  # We assume that you have downloaded the Emilia corpus
-  # to $dl_dir/Amphion___Emilia
-  # see stage 0 for the directory structure
-  mkdir -p data/manifests
-  if [ ! -e data/manifests/.emilia.done ]; then
-    lhotse prepare emilia --lang en --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
-    lhotse prepare emilia --lang zh --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
-    touch data/manifests/.emilia.done
-  fi
-fi
-
-if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Preprocess Emilia dataset, mainly for cleaning"
-  mkdir -p data/manifests/splits_raw
-  if [ ! -e data/manifests/splits_raw/.emilia.split.done ]; then
-    lhotse split-lazy data/manifests/emilia_cuts_EN.jsonl.gz data/manifests/splits_raw 10000
-    lhotse split-lazy data/manifests/emilia_cuts_ZH.jsonl.gz data/manifests/splits_raw 10000
-    touch data/manifests/splits_raw/.emilia.split.done
-  fi
-
-  mkdir -p data/manifests/splits
-
-  if [ ! -e data/manifests/splits/.emilia.preprocess.done ]; then
-    python local/preprocess_emilia.py --subset EN
-    python local/preprocess_emilia.py --subset ZH
-    touch data/manifests/splits/.emilia.preprocess.done
-  fi
-
-fi
-
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Extract Fbank for Emilia"
-  mkdir -p data/fbank/emilia_splits
-  if [ ! -e data/fbank/emilia_splits/.emilia.fbank.done ]; then
-    # You can speed up the extraction by distributing splits to multiple machines.
-    for subset in EN ZH; do
-      python local/compute_fbank.py \
-        --source-dir data/manifests/splits \
-        --dest-dir data/fbank/emilia_splits \
-        --dataset emilia \
-        --subset ${subset} \
-        --splits-cuts 1 \
-        --split-begin 0 \
-        --split-end 2000 \
-        --num-jobs ${nj}
-    done
-    touch data/fbank/emilia_splits/.emilia.fbank.done
-  fi
-
-  if [ ! -e data/fbank/emilia_cuts_EN.jsonl.gz ]; then
-    log "Combining EN fbank cuts and splitting EN dev set"
-    gunzip -c data/fbank/emilia_splits/emilia_cuts_EN.*.jsonl.gz > data/fbank/emilia_cuts_EN.jsonl
-    head -n 1500 data/fbank/emilia_cuts_EN.jsonl | gzip -c > data/fbank/emilia_cuts_EN_dev.jsonl.gz
-    sed -i '1,1500d' data/fbank/emilia_cuts_EN.jsonl
-    gzip data/fbank/emilia_cuts_EN.jsonl
-  fi
-
-  if [ ! -e data/fbank/emilia_cuts_ZH.jsonl.gz ]; then
-    log "Combining ZH fbank cuts and splitting ZH dev set"
-    gunzip -c data/fbank/emilia_splits/emilia_cuts_ZH.*.jsonl.gz > data/fbank/emilia_cuts_ZH.jsonl
-    head -n 1500 data/fbank/emilia_cuts_ZH.jsonl | gzip -c > data/fbank/emilia_cuts_ZH_dev.jsonl.gz
-    sed -i '1,1500d' data/fbank/emilia_cuts_ZH.jsonl
-    gzip data/fbank/emilia_cuts_ZH.jsonl
-  fi
-
-fi
-
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Generate token file"
-  if [ ! 
-e data/tokens_emilia.txt ]; then - ./local/prepare_token_file_emilia.py --tokens data/tokens_emilia.txt - fi -fi diff --git a/egs/zipvoice/scripts/prepare_libritts.sh b/egs/zipvoice/scripts/prepare_libritts.sh deleted file mode 100755 index 6d643145e..000000000 --- a/egs/zipvoice/scripts/prepare_libritts.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash - -# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 -export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python - -set -eou pipefail - -stage=0 -stop_stage=5 -sampling_rate=24000 -nj=20 - -dl_dir=$PWD/download - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/LibriTTS, - # you can create a symlink - # - # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS - # - if [ ! -d $dl_dir/LibriTTS ]; then - lhotse download libritts $dl_dir - fi - -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare LibriTTS manifest" - # We assume that you have downloaded the LibriTTS corpus - # to $dl_dir/LibriTTS - mkdir -p data/manifests_libritts - if [ ! -e data/manifests_libritts/.libritts.done ]; then - lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests - touch data/manifests/.libritts.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute Fbank for LibriTTS" - mkdir -p data/fbank - - if [ ! -e data/fbank/.libritts.done ]; then - for subset in train-clean-100 train-clean-360 train-other-500 dev-clean test-clean; do - python local/compute_fbank.py \ - --source-dir data/manifests \ - --dest-dir data/fbank \ - --dataset libritts \ - --subset ${subset} \ - --sampling-rate $sampling_rate \ - --num-jobs ${nj} - done - touch data/fbank/.libritts.done - fi - - # Here we shuffle and combine the train-clean-100, train-clean-360 and - # train-other-500 together to form the training set. - if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then - cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \ - shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz - fi - - if [ ! -f data/fbank/libritts_cuts_train-clean-460.jsonl.gz ]; then - cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) | \ - shuf | gzip -c > data/fbank/libritts_cuts_train-clean-460.jsonl.gz - fi - - if [ ! -e data/fbank/.libritts-validated.done ]; then - log "Validating data/fbank for LibriTTS" - ./local/validate_manifest.py \ - data/fbank/libritts_cuts_train-all-shuf.jsonl.gz - touch data/fbank/.libritts-validated.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Generate token file" - if [ ! 
-e data/tokens_libritts.txt ]; then - ./local/prepare_token_file_libritts.py --tokens data/tokens_libritts.txt - fi -fi diff --git a/egs/zipvoice/zipvoice/checkpoint.py b/egs/zipvoice/zipvoice/checkpoint.py deleted file mode 100644 index e3acd57dd..000000000 --- a/egs/zipvoice/zipvoice/checkpoint.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -from pathlib import Path -from typing import Any, Dict, Optional, Union - -import torch -import torch.nn as nn -from lhotse.dataset.sampling.base import CutSampler -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer - -# use duck typing for LRScheduler since we have different possibilities, see -# our class LRScheduler. -LRSchedulerType = object - - -def save_checkpoint( - filename: Path, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - model_ema: Optional[nn.Module] = None, - params: Optional[Dict[str, Any]] = None, - optimizer: Optional[Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, - sampler: Optional[CutSampler] = None, - rank: int = 0, -) -> None: - """Save training information to a file. - - Args: - filename: - The checkpoint filename. - model: - The model to be saved. We only save its `state_dict()`. - model_avg: - The stored model averaged from the start of training. - model_ema: - The EMA version of model. - params: - User defined parameters, e.g., epoch, loss. - optimizer: - The optimizer to be saved. We only save its `state_dict()`. - scheduler: - The scheduler to be saved. We only save its `state_dict()`. - scalar: - The GradScaler to be saved. We only save its `state_dict()`. - sampler: - The sampler used in the labeled training dataset. We only - save its `state_dict()`. - rank: - Used in DDP. We save checkpoint only for the node whose - rank is 0. - Returns: - Return None. 
- """ - if rank != 0: - return - - logging.info(f"Saving checkpoint to {filename}") - - if isinstance(model, DDP): - model = model.module - - checkpoint = { - "model": model.state_dict(), - "optimizer": optimizer.state_dict() if optimizer is not None else None, - "scheduler": scheduler.state_dict() if scheduler is not None else None, - "grad_scaler": scaler.state_dict() if scaler is not None else None, - "sampler": sampler.state_dict() if sampler is not None else None, - } - - if model_avg is not None: - checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() - if model_ema is not None: - checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict() - - if params: - for k, v in params.items(): - assert k not in checkpoint - checkpoint[k] = v - - torch.save(checkpoint, filename) - - -def load_checkpoint( - filename: Path, - model: Optional[nn.Module] = None, - model_avg: Optional[nn.Module] = None, - model_ema: Optional[nn.Module] = None, - strict: bool = False, -) -> Dict[str, Any]: - logging.info(f"Loading checkpoint from {filename}") - checkpoint = torch.load(filename, map_location="cpu") - - if model is not None: - - if next(iter(checkpoint["model"])).startswith("module."): - logging.info("Loading checkpoint saved by DDP") - - dst_state_dict = model.state_dict() - src_state_dict = checkpoint["model"] - for key in dst_state_dict.keys(): - src_key = "{}.{}".format("module", key) - dst_state_dict[key] = src_state_dict.pop(src_key) - assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict, strict=strict) - else: - logging.info("Loading checkpoint") - model.load_state_dict(checkpoint["model"], strict=strict) - - checkpoint.pop("model") - - if model_avg is not None and "model_avg" in checkpoint: - logging.info("Loading averaged model") - model_avg.load_state_dict(checkpoint["model_avg"], strict=strict) - checkpoint.pop("model_avg") - - if model_ema is not None and "model_ema" in checkpoint: - logging.info("Loading ema model") - model_ema.load_state_dict(checkpoint["model_ema"], strict=strict) - checkpoint.pop("model_ema") - - return checkpoint diff --git a/egs/zipvoice/zipvoice/feature.py b/egs/zipvoice/zipvoice/feature.py deleted file mode 100644 index e7d484d10..000000000 --- a/egs/zipvoice/zipvoice/feature.py +++ /dev/null @@ -1,135 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
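A note on the DDP handling in `load_checkpoint` above: checkpoints saved from a `DistributedDataParallel` model have every key prefixed with `module.`, and the loader strips that prefix before calling `load_state_dict`. A minimal sketch of that renaming on a toy state dict (illustrative only, not the recipe's code):

```python
# Strip the "module." prefix that DistributedDataParallel adds to state-dict keys.
import torch

ddp_state = {"module.embed.weight": torch.zeros(3, 2)}  # toy DDP-style checkpoint
plain_state = {
    key[len("module."):]: value
    for key, value in ddp_state.items()
    if key.startswith("module.")
}
print(list(plain_state))  # ['embed.weight']
```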
- -from dataclasses import dataclass -from typing import Union - -import numpy as np -import torch -import torch.nn as nn -import torchaudio -from lhotse.features.base import FeatureExtractor, register_extractor -from lhotse.utils import Seconds, compute_num_frames - - -class MelSpectrogramFeatures(nn.Module): - def __init__( - self, - sampling_rate=24000, - n_mels=100, - n_fft=1024, - hop_length=256, - ): - super().__init__() - - self.mel_spec = torchaudio.transforms.MelSpectrogram( - sample_rate=sampling_rate, - n_fft=n_fft, - hop_length=hop_length, - n_mels=n_mels, - center=True, - power=1, - ) - - def forward(self, inp): - assert len(inp.shape) == 2 - - mel = self.mel_spec(inp) - logmel = mel.clamp(min=1e-7).log() - return logmel - - -@dataclass -class TorchAudioFbankConfig: - sampling_rate: int - n_mels: int - n_fft: int - hop_length: int - - -@register_extractor -class TorchAudioFbank(FeatureExtractor): - - name = "TorchAudioFbank" - config_type = TorchAudioFbankConfig - - def __init__(self, config): - super().__init__(config=config) - - def _feature_fn(self, sample): - fbank = MelSpectrogramFeatures( - sampling_rate=self.config.sampling_rate, - n_mels=self.config.n_mels, - n_fft=self.config.n_fft, - hop_length=self.config.hop_length, - ) - - return fbank(sample) - - @property - def device(self) -> Union[str, torch.device]: - return self.config.device - - def feature_dim(self, sampling_rate: int) -> int: - return self.config.n_mels - - def extract( - self, - samples: Union[np.ndarray, torch.Tensor], - sampling_rate: int, - ) -> Union[np.ndarray, torch.Tensor]: - # Check for sampling rate compatibility. - expected_sr = self.config.sampling_rate - assert sampling_rate == expected_sr, ( - f"Mismatched sampling rate: extractor expects {expected_sr}, " - f"got {sampling_rate}" - ) - is_numpy = False - if not isinstance(samples, torch.Tensor): - samples = torch.from_numpy(samples) - is_numpy = True - - if len(samples.shape) == 1: - samples = samples.unsqueeze(0) - assert samples.ndim == 2, samples.shape - assert samples.shape[0] == 1, samples.shape - - mel = self._feature_fn(samples).squeeze().t() - - assert mel.ndim == 2, mel.shape - assert mel.shape[1] == self.config.n_mels, mel.shape - - num_frames = compute_num_frames( - samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate - ) - - if mel.shape[0] > num_frames: - mel = mel[:num_frames] - elif mel.shape[0] < num_frames: - mel = mel.unsqueeze(0) - mel = torch.nn.functional.pad( - mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" - ).squeeze(0) - - if is_numpy: - return mel.cpu().numpy() - else: - return mel - - @property - def frame_shift(self) -> Seconds: - return self.config.hop_length / self.config.sampling_rate diff --git a/egs/zipvoice/zipvoice/generate_averaged_model.py b/egs/zipvoice/zipvoice/generate_averaged_model.py deleted file mode 100644 index e1b7ca7c6..000000000 --- a/egs/zipvoice/zipvoice/generate_averaged_model.py +++ /dev/null @@ -1,209 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
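A minimal usage sketch for the `TorchAudioFbank` extractor defined in `feature.py` above, with a 24 kHz / 100-mel / 1024-FFT / 256-hop configuration (the same values the inference code passes); the `from feature import ...` line assumes you run it from the recipe's `zipvoice` directory:

```python
# Extract log-mel features from a dummy waveform with the TorchAudioFbank extractor above.
import torch
from feature import TorchAudioFbank, TorchAudioFbankConfig

config = TorchAudioFbankConfig(sampling_rate=24000, n_mels=100, n_fft=1024, hop_length=256)
extractor = TorchAudioFbank(config)

wav = torch.randn(1, 24000)  # one second of fake audio, shape (1, num_samples)
mel = extractor.extract(wav, sampling_rate=24000)  # (num_frames, 100) log-mel features
print(mel.shape, extractor.frame_shift)  # frame shift is 256 / 24000 seconds
```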
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -This script loads checkpoints and averages them. - -(1) Average ZipVoice models before distill: - python3 ./zipvoice/generate_averaged_model.py \ - --epoch 11 \ - --avg 4 \ - --distill 0 \ - --token-file data/tokens_emilia.txt \ - --exp-dir ./zipvoice/exp_zipvoice - - It will generate a file `epoch-11-avg-14.pt` in the given `exp_dir`. - You can later load it by `torch.load("epoch-11-avg-4.pt")`. - -(2) Average ZipVoice-Distill models (the first stage model): - - python3 ./zipvoice/generate_averaged_model.py \ - --iter 60000 \ - --avg 7 \ - --distill 1 \ - --token-file data/tokens_emilia.txt \ - --exp-dir ./zipvoice/exp_zipvoice_distill_1stage -""" - -import argparse -from pathlib import Path - -import torch -from model import get_distill_model, get_model -from tokenizer import TokenizerEmilia, TokenizerLibriTTS -from train_flow import add_model_arguments, get_params - -from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=11, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=4, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' or --iter", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipvoice/exp_zipvoice", - help="The experiment dir", - ) - - parser.add_argument( - "--distill", - type=str2bool, - default=False, - help="Whether to use distill model. 
", - ) - - parser.add_argument( - "--dataset", - type=str, - default="emilia", - choices=["emilia", "libritts"], - help="The used training dataset for the model to inference", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if params.dataset == "emilia": - tokenizer = TokenizerEmilia( - token_file=params.token_file, token_type=params.token_type - ) - elif params.dataset == "libritts": - tokenizer = TokenizerLibriTTS( - token_file=params.token_file, token_type=params.token_type - ) - - params.vocab_size = tokenizer.vocab_size - params.pad_id = tokenizer.pad_id - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - print("Script started") - - params.device = torch.device("cpu") - print(f"Device: {params.device}") - - print("About to create model") - if params.distill: - model = get_distill_model(params) - else: - model = get_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - print( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(params.device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=params.device, - ), - strict=True, - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - print( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(params.device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=params.device, - ), - strict=True, - ) - if params.iter > 0: - filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt" - else: - filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save({"model": model.state_dict()}, filename) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/zipvoice/zipvoice/infer.py b/egs/zipvoice/zipvoice/infer.py deleted file mode 100644 index 2819d3c85..000000000 --- a/egs/zipvoice/zipvoice/infer.py +++ /dev/null @@ -1,586 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Wei Kang -# Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script loads checkpoints to generate waveforms. -This script is supposed to be used with the model trained by yourself. -If you want to use the pre-trained checkpoints provided by us, please refer to zipvoice_infer.py. - -(1) Usage with a pre-trained checkpoint: - - (a) ZipVoice model before distill: - python3 zipvoice/infer.py \ - --checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \ - --distill 0 \ - --token-file "data/tokens_emilia.txt" \ - --test-list test.tsv \ - --res-dir results/test \ - --num-step 16 \ - --guidance-scale 1 - - (b) ZipVoice-Distill: - python3 zipvoice/infer.py \ - --checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \ - --distill 1 \ - --token-file "data/tokens_emilia.txt" \ - --test-list test.tsv \ - --res-dir results/test_distill \ - --num-step 8 \ - --guidance-scale 3 - -(2) Usage with a directory of checkpoints (may requires checkpoint averaging): - - (a) ZipVoice model before distill: - python3 flow_match/infer.py \ - --exp-dir zipvoice/exp_zipvoice \ - --epoch 11 \ - --avg 4 \ - --distill 0 \ - --token-file "data/tokens_emilia.txt" \ - --test-list test.tsv \ - --res-dir results \ - --num-step 16 \ - --guidance-scale 1 - - (b) ZipVoice-Distill: - python3 flow_match/infer.py \ - --exp-dir zipvoice/exp_zipvoice_distill/ \ - --iter 2000 \ - --avg 0 \ - --distill 1 \ - --token-file "data/tokens_emilia.txt" \ - --test-list test.tsv \ - --res-dir results \ - --num-step 8 \ - --guidance-scale 3 -""" - - -import argparse -import datetime as dt -import logging -import os -from pathlib import Path -from typing import Optional - -import numpy as np -import soundfile as sf -import torch -import torch.nn as nn -import torchaudio -from checkpoint import load_checkpoint -from feature import TorchAudioFbank, TorchAudioFbankConfig -from lhotse.utils import fix_random_seed -from model import get_distill_model, get_model -from tokenizer import TokenizerEmilia, TokenizerLibriTTS -from train_flow import add_model_arguments, get_params -from vocos import Vocos - -from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints -from icefall.utils import AttributeDict, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - default=None, - help="The checkpoint for inference. " - "If it is None, will use checkpoints under exp_dir", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipvoice/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--epoch", - type=int, - default=0, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. 
- """, - ) - - parser.add_argument( - "--avg", - type=int, - default=4, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' or '--iter', avg=0 means no avg", - ) - - parser.add_argument( - "--vocoder-path", - type=str, - default=None, - help="The local vocos vocoder path, downloaded from huggingface, " - "will download the vocodoer from huggingface if it is None.", - ) - - parser.add_argument( - "--distill", - type=str2bool, - default=False, - help="Whether it is a distilled TTS model.", - ) - - parser.add_argument( - "--test-list", - type=str, - default=None, - help="The list of prompt speech, prompt_transcription, " - "and text to synthesize in the format of " - "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.", - ) - - parser.add_argument( - "--res-dir", - type=str, - default="results", - help="Path name of the generated wavs dir", - ) - - parser.add_argument( - "--dataset", - type=str, - default="emilia", - choices=["emilia", "libritts"], - help="The used training dataset for the model to inference", - ) - - parser.add_argument( - "--guidance-scale", - type=float, - default=1.0, - help="The scale of classifier-free guidance during inference.", - ) - - parser.add_argument( - "--num-step", - type=int, - default=16, - help="The number of sampling steps.", - ) - - parser.add_argument( - "--feat-scale", - type=float, - default=0.1, - help="The scale factor of fbank feature", - ) - - parser.add_argument( - "--speed", - type=float, - default=1.0, - help="Control speech speed, 1.0 means normal, >1.0 means speed up", - ) - - parser.add_argument( - "--t-shift", - type=float, - default=0.5, - help="Shift t to smaller ones if t_shift < 1.0", - ) - - parser.add_argument( - "--target-rms", - type=float, - default=0.1, - help="Target speech normalization rms value", - ) - - parser.add_argument( - "--seed", - type=int, - default=666, - help="Random seed", - ) - - add_model_arguments(parser) - - return parser - - -def get_vocoder(vocos_local_path: Optional[str] = None): - if vocos_local_path: - vocos_local_path = "model/huggingface/vocos-mel-24khz/" - vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml") - state_dict = torch.load( - f"{vocos_local_path}/pytorch_model.bin", - weights_only=True, - map_location="cpu", - ) - vocoder.load_state_dict(state_dict) - else: - vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz") - return vocoder - - -def generate_sentence( - save_path: str, - prompt_text: str, - prompt_wav: str, - text: str, - model: nn.Module, - vocoder: nn.Module, - tokenizer: TokenizerEmilia, - feature_extractor: TorchAudioFbank, - device: torch.device, - num_step: int = 16, - guidance_scale: float = 1.0, - speed: float = 1.0, - t_shift: float = 0.5, - target_rms: float = 0.1, - feat_scale: float = 0.1, - sampling_rate: int = 24000, -): - """ - Generate waveform of a text based on a given prompt - waveform and its transcription. - - Args: - save_path (str): Path to save the generated wav. - prompt_text (str): Transcription of the prompt wav. - prompt_wav (str): Path to the prompt wav file. - text (str): Text to be synthesized into a waveform. - model (nn.Module): The model used for generation. - vocoder (nn.Module): The vocoder used to convert features to waveforms. - tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens. - feature_extractor (TorchAudioFbank): The feature extractor used to - extract acoustic features. 
- device (torch.device): The device on which computations are performed. - num_step (int, optional): Number of steps for decoding. Defaults to 16. - guidance_scale (float, optional): Scale for classifier-free guidance. - Defaults to 1.0. - speed (float, optional): Speed control. Defaults to 1.0. - t_shift (float, optional): Time shift. Defaults to 0.5. - target_rms (float, optional): Target RMS for waveform normalization. - Defaults to 0.1. - feat_scale (float, optional): Scale for features. - Defaults to 0.1. - sampling_rate (int, optional): Sampling rate for the waveform. - Defaults to 24000. - Returns: - metrics (dict): Dictionary containing time and real-time - factor metrics for processing. - """ - # Convert text to tokens - tokens = tokenizer.texts_to_token_ids([text]) - prompt_tokens = tokenizer.texts_to_token_ids([prompt_text]) - - # Load and preprocess prompt wav - prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav) - prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav))) - if prompt_rms < target_rms: - prompt_wav = prompt_wav * target_rms / prompt_rms - - if prompt_sampling_rate != sampling_rate: - resampler = torchaudio.transforms.Resample( - orig_freq=prompt_sampling_rate, new_freq=sampling_rate - ) - prompt_wav = resampler(prompt_wav) - - # Extract features from prompt wav - prompt_features = feature_extractor.extract( - prompt_wav, sampling_rate=sampling_rate - ).to(device) - prompt_features = prompt_features.unsqueeze(0) * feat_scale - prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device) - - # Start timing - start_t = dt.datetime.now() - - # Generate features - ( - pred_features, - pred_features_lens, - pred_prompt_features, - pred_prompt_features_lens, - ) = model.sample( - tokens=tokens, - prompt_tokens=prompt_tokens, - prompt_features=prompt_features, - prompt_features_lens=prompt_features_lens, - speed=speed, - t_shift=t_shift, - duration="predict", - num_step=num_step, - guidance_scale=guidance_scale, - ) - - # Postprocess predicted features - pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T) - - # Start vocoder processing - start_vocoder_t = dt.datetime.now() - wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1) - - # Calculate processing times and real-time factors - t = (dt.datetime.now() - start_t).total_seconds() - t_no_vocoder = (start_vocoder_t - start_t).total_seconds() - t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds() - wav_seconds = wav.shape[-1] / sampling_rate - rtf = t / wav_seconds - rtf_no_vocoder = t_no_vocoder / wav_seconds - rtf_vocoder = t_vocoder / wav_seconds - metrics = { - "t": t, - "t_no_vocoder": t_no_vocoder, - "t_vocoder": t_vocoder, - "wav_seconds": wav_seconds, - "rtf": rtf, - "rtf_no_vocoder": rtf_no_vocoder, - "rtf_vocoder": rtf_vocoder, - } - - # Adjust wav volume if necessary - if prompt_rms < target_rms: - wav = wav * prompt_rms / target_rms - wav = wav[0].cpu().numpy() - sf.write(save_path, wav, sampling_rate) - - return metrics - - -def generate( - params: AttributeDict, - model: nn.Module, - vocoder: nn.Module, - tokenizer: TokenizerEmilia, -): - total_t = [] - total_t_no_vocoder = [] - total_t_vocoder = [] - total_wav_seconds = [] - - config = TorchAudioFbankConfig( - sampling_rate=params.sampling_rate, - n_mels=100, - n_fft=1024, - hop_length=256, - ) - feature_extractor = TorchAudioFbank(config) - - with open(params.test_list, "r") as fr: - lines = fr.readlines() - - for i, line in enumerate(lines): - wav_name, prompt_text, prompt_wav, text 
= line.strip().split("\t") - save_path = f"{params.wav_dir}/{wav_name}.wav" - metrics = generate_sentence( - save_path=save_path, - prompt_text=prompt_text, - prompt_wav=prompt_wav, - text=text, - model=model, - vocoder=vocoder, - tokenizer=tokenizer, - feature_extractor=feature_extractor, - device=params.device, - num_step=params.num_step, - guidance_scale=params.guidance_scale, - speed=params.speed, - t_shift=params.t_shift, - target_rms=params.target_rms, - feat_scale=params.feat_scale, - sampling_rate=params.sampling_rate, - ) - print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}") - total_t.append(metrics["t"]) - total_t_no_vocoder.append(metrics["t_no_vocoder"]) - total_t_vocoder.append(metrics["t_vocoder"]) - total_wav_seconds.append(metrics["wav_seconds"]) - - print(f"Average RTF: " f"{np.sum(total_t)/np.sum(total_wav_seconds):.4f}") - print( - f"Average RTF w/o vocoder: " - f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}" - ) - print( - f"Average RTF vocoder: " - f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}" - ) - - -@torch.inference_mode() -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if params.iter > 0: - params.suffix = ( - f"wavs-iter-{params.iter}-avg" - f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}" - ) - elif params.epoch > 0: - params.suffix = ( - f"wavs-epoch-{params.epoch}-avg" - f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}" - ) - else: - params.suffix = "wavs" - - setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") - logging.info("Decoding started") - - if torch.cuda.is_available(): - params.device = torch.device("cuda", 0) - else: - params.device = torch.device("cpu") - - logging.info(f"Device: {params.device}") - - if params.dataset == "emilia": - tokenizer = TokenizerEmilia( - token_file=params.token_file, token_type=params.token_type - ) - elif params.dataset == "libritts": - tokenizer = TokenizerLibriTTS( - token_file=params.token_file, token_type=params.token_type - ) - - params.vocab_size = tokenizer.vocab_size - params.pad_id = tokenizer.pad_id - - logging.info(params) - fix_random_seed(params.seed) - - logging.info("About to create model") - if params.distill: - model = get_distill_model(params) - else: - model = get_model(params) - - if params.checkpoint: - load_checkpoint(params.checkpoint, model, strict=True) - else: - if params.avg == 0: - if params.iter > 0: - load_checkpoint( - f"{params.exp_dir}/checkpoint-{params.iter}.pt", model, strict=True - ) - else: - load_checkpoint( - f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True - ) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(params.device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=params.device, - ), - strict=True, - ) - else: - 
assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(params.device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=params.device, - ), - strict=True, - ) - - model = model.to(params.device) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - vocoder = get_vocoder(params.vocoder_path) - vocoder = vocoder.to(params.device) - vocoder.eval() - num_param = sum([p.numel() for p in vocoder.parameters()]) - logging.info(f"Number of vocoder parameters: {num_param}") - - params.wav_dir = f"{params.res_dir}/{params.suffix}" - os.makedirs(params.wav_dir, exist_ok=True) - - assert ( - params.test_list is not None - ), "Please provide --test-list for speech synthesize." - generate( - params=params, - model=model, - vocoder=vocoder, - tokenizer=tokenizer, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/zipvoice/zipvoice/model.py b/egs/zipvoice/zipvoice/model.py deleted file mode 100644 index 25c7973b2..000000000 --- a/egs/zipvoice/zipvoice/model.py +++ /dev/null @@ -1,578 +0,0 @@ -# Copyright 2024 Xiaomi Corp. (authors: Wei Kang -# Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
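`infer.py` above reads `--test-list` as a tab-separated file with one sentence per line, in the order wav_name, prompt transcription, prompt wav path, target text. A minimal sketch of writing such a file (all names and texts below are made-up placeholders):

```python
# Write a test.tsv in the format expected by infer.py's --test-list option.
rows = [
    (
        "utt_0001",                                # name of the generated wav
        "transcription of the prompt audio",       # prompt transcription
        "prompts/speaker_a.wav",                   # path to the prompt wav
        "text to synthesize in the prompt voice",  # target text
    ),
]
with open("test.tsv", "w", encoding="utf-8") as f:
    for wav_name, prompt_text, prompt_wav, text in rows:
        f.write("\t".join([wav_name, prompt_text, prompt_wav, text]) + "\n")
```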
- -from typing import List, Optional - -import torch -import torch.nn as nn -from scaling import ScheduledFloat -from solver import EulerSolver -from torch.nn.parallel import DistributedDataParallel as DDP -from utils import ( - AttributeDict, - condition_time_mask, - get_tokens_index, - make_pad_mask, - pad_labels, - prepare_avg_tokens_durations, - to_int_tuple, -) -from zipformer import TTSZipformer - - -def get_model(params: AttributeDict) -> nn.Module: - """Get the normal TTS model.""" - - fm_decoder = get_fm_decoder_model(params) - text_encoder = get_text_encoder_model(params) - - model = TtsModel( - fm_decoder=fm_decoder, - text_encoder=text_encoder, - text_embed_dim=params.text_embed_dim, - feat_dim=params.feat_dim, - vocab_size=params.vocab_size, - pad_id=params.pad_id, - ) - return model - - -def get_distill_model(params: AttributeDict) -> nn.Module: - """Get the distillation TTS model.""" - - fm_decoder = get_fm_decoder_model(params, distill=True) - text_encoder = get_text_encoder_model(params) - - model = DistillTTSModelTrainWrapper( - fm_decoder=fm_decoder, - text_encoder=text_encoder, - text_embed_dim=params.text_embed_dim, - feat_dim=params.feat_dim, - vocab_size=params.vocab_size, - pad_id=params.pad_id, - ) - return model - - -def get_fm_decoder_model(params: AttributeDict, distill: bool = False) -> nn.Module: - """Get the Zipformer-based FM decoder model.""" - - encoder = TTSZipformer( - in_dim=params.feat_dim * 3, - out_dim=params.feat_dim, - downsampling_factor=to_int_tuple(params.fm_decoder_downsampling_factor), - num_encoder_layers=to_int_tuple(params.fm_decoder_num_layers), - cnn_module_kernel=to_int_tuple(params.fm_decoder_cnn_module_kernel), - encoder_dim=params.fm_decoder_dim, - feedforward_dim=params.fm_decoder_feedforward_dim, - num_heads=params.fm_decoder_num_heads, - query_head_dim=params.query_head_dim, - pos_head_dim=params.pos_head_dim, - value_head_dim=params.value_head_dim, - pos_dim=params.pos_dim, - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - use_time_embed=True, - time_embed_dim=192, - use_guidance_scale_embed=distill, - ) - return encoder - - -def get_text_encoder_model(params: AttributeDict) -> nn.Module: - """Get the Zipformer-based text encoder model.""" - - encoder = TTSZipformer( - in_dim=params.text_embed_dim, - out_dim=params.feat_dim, - downsampling_factor=to_int_tuple(params.text_encoder_downsampling_factor), - num_encoder_layers=to_int_tuple(params.text_encoder_num_layers), - cnn_module_kernel=to_int_tuple(params.text_encoder_cnn_module_kernel), - encoder_dim=params.text_encoder_dim, - feedforward_dim=params.text_encoder_feedforward_dim, - num_heads=params.text_encoder_num_heads, - query_head_dim=params.query_head_dim, - pos_head_dim=params.pos_head_dim, - value_head_dim=params.value_head_dim, - pos_dim=params.pos_dim, - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - use_time_embed=False, - ) - return encoder - - -class TtsModel(nn.Module): - """The normal TTS model.""" - - def __init__( - self, - fm_decoder: nn.Module, - text_encoder: nn.Module, - text_embed_dim: int, - feat_dim: int, - vocab_size: int, - pad_id: int = 0, - ): - """ - Args: - fm_decoder: the flow-matching encoder model, inputs are the - input condition embeddings and noisy acoustic features, - outputs are better acoustic features. - text_encoder: the text encoder model. input are text - embeddings, output are contextualized text embeddings. - text_embed_dim: dimension of text embedding. 
- feat_dim: dimension of acoustic features. - vocab_size: vocabulary size. - pad_id: padding id. - """ - super().__init__() - - self.feat_dim = feat_dim - self.text_embed_dim = text_embed_dim - self.pad_id = pad_id - - self.fm_decoder = fm_decoder - - self.text_encoder = text_encoder - - self.embed = nn.Embedding(vocab_size, text_embed_dim) - - self.distill = False - - def forward_fm_decoder( - self, - t: torch.Tensor, - xt: torch.Tensor, - text_condition: torch.Tensor, - speech_condition: torch.Tensor, - padding_mask: Optional[torch.Tensor] = None, - guidance_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Compute velocity. - Args: - t: A tensor of shape (N, 1, 1) or a tensor of a float, - in the range of (0, 1). - xt: the input of the current timestep, including condition - embeddings and noisy acoustic features. - text_condition: the text condition embeddings, with the - shape (batch, seq_len, emb_dim). - speech_condition: the speech condition embeddings, with the - shape (batch, seq_len, emb_dim). - padding_mask: The mask for padding, True means masked - position, with the shape (N, T). - guidance_scale: The guidance scale in classifier-free guidance, - which is a tensor of shape (N, 1, 1) or a tensor of a float. - - Returns: - predicted velocity, with the shape (batch, seq_len, emb_dim). - """ - assert t.dim() in (0, 3) - # Handle t with the shape (N, 1, 1): - # squeeze the last dimension if it's size is 1. - while t.dim() > 1 and t.size(-1) == 1: - t = t.squeeze(-1) - if guidance_scale is not None: - while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1: - guidance_scale = guidance_scale.squeeze(-1) - # Handle t with a single value: expand to the size of batch size. - if t.dim() == 0: - t = t.repeat(xt.shape[0]) - if guidance_scale is not None and guidance_scale.dim() == 0: - guidance_scale = guidance_scale.repeat(xt.shape[0]) - - xt = torch.cat([xt, text_condition, speech_condition], dim=2) - vt = self.fm_decoder( - x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale - ) - return vt - - def forward_text_embed( - self, - tokens: List[List[int]], - ): - """ - Get the text embeddings. - Args: - tokens: a list of list of token ids. - Returns: - embed: the text embeddings, shape (batch, seq_len, emb_dim). - tokens_lens: the length of each token sequence, shape (batch,). - """ - device = ( - self.device if isinstance(self, DDP) else next(self.parameters()).device - ) - tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S) - embed = self.embed(tokens_padded) # (B, S, C) - tokens_lens = torch.tensor( - [len(token) for token in tokens], dtype=torch.int64, device=device - ) - tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S) - - embed = self.text_encoder( - x=embed, t=None, padding_mask=tokens_padding_mask - ) # (B, S, C) - return embed, tokens_lens - - def forward_text_condition( - self, - embed: torch.Tensor, - tokens_lens: torch.Tensor, - features_lens: torch.Tensor, - ): - """ - Get the text condition with the same length of the acoustic feature. - Args: - embed: the text embeddings, shape (batch, token_seq_len, emb_dim). - tokens_lens: the length of each token sequence, shape (batch,). - features_lens: the length of each acoustic feature sequence, - shape (batch,). - Returns: - text_condition: the text condition, shape - (batch, feature_seq_len, emb_dim). - padding_mask: the padding mask of text condition, shape - (batch, feature_seq_len). 
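`forward_text_embed` above pads the token sequences with `pad_labels` and builds a padding mask with `make_pad_mask`, both defined in `utils.py` and not part of this diff. A minimal sketch of the masking convention assumed throughout this file (True marks padded positions, matching the docstrings here); the function name is illustrative:

```python
import torch

def make_pad_mask_sketch(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """True at padded positions, False at real positions; shape (B, max_len)."""
    idx = torch.arange(max_len, device=lengths.device)
    return idx.unsqueeze(0) >= lengths.unsqueeze(1)

lengths = torch.tensor([3, 5])
print(make_pad_mask_sketch(lengths, 5))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
```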
- """ - - num_frames = int(features_lens.max()) - - padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T) - - tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens) - - tokens_index = get_tokens_index(tokens_durations, num_frames).to( - embed.device - ) # (B, T) - - text_condition = torch.gather( - embed, - dim=1, - index=tokens_index.unsqueeze(-1).expand( - embed.size(0), num_frames, embed.size(-1) - ), - ) # (B, T, F) - return text_condition, padding_mask - - def forward_text_train( - self, - tokens: List[List[int]], - features_lens: torch.Tensor, - ): - """ - Process text for training, given text tokens and real feature lengths. - """ - embed, tokens_lens = self.forward_text_embed(tokens) - text_condition, padding_mask = self.forward_text_condition( - embed, tokens_lens, features_lens - ) - return ( - text_condition, - padding_mask, - ) - - def forward_text_inference_gt_duration( - self, - tokens: List[List[int]], - features_lens: torch.Tensor, - prompt_tokens: List[List[int]], - prompt_features_lens: torch.Tensor, - ): - """ - Process text for inference, given text tokens, real feature lengths and prompts. - """ - tokens = [ - prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens) - ] - features_lens = prompt_features_lens + features_lens - embed, tokens_lens = self.forward_text_embed(tokens) - text_condition, padding_mask = self.forward_text_condition( - embed, tokens_lens, features_lens - ) - return text_condition, padding_mask - - def forward_text_inference_ratio_duration( - self, - tokens: List[List[int]], - prompt_tokens: List[List[int]], - prompt_features_lens: torch.Tensor, - speed: float, - ): - """ - Process text for inference, given text tokens and prompts, - feature lengths are predicted with the ratio of token numbers. - """ - device = ( - self.device if isinstance(self, DDP) else next(self.parameters()).device - ) - - cat_tokens = [ - prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens) - ] - - prompt_tokens_lens = torch.tensor( - [len(token) for token in prompt_tokens], dtype=torch.int64, device=device - ) - - cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens) - - features_lens = torch.ceil( - (prompt_features_lens / prompt_tokens_lens * cat_tokens_lens / speed) - ).to(dtype=torch.int64) - - text_condition, padding_mask = self.forward_text_condition( - cat_embed, cat_tokens_lens, features_lens - ) - return text_condition, padding_mask - - def forward( - self, - tokens: List[List[int]], - features: torch.Tensor, - features_lens: torch.Tensor, - noise: torch.Tensor, - t: torch.Tensor, - condition_drop_ratio: float = 0.0, - ) -> torch.Tensor: - """Forward pass of the model for training. - Args: - tokens: a list of list of token ids. - features: the acoustic features, with the shape (batch, seq_len, feat_dim). - features_lens: the length of each acoustic feature sequence, shape (batch,). - noise: the intitial noise, with the shape (batch, seq_len, feat_dim). - t: the time step, with the shape (batch, 1, 1). - condition_drop_ratio: the ratio of dropped text condition. - Returns: - fm_loss: the flow-matching loss. 
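`forward_text_condition` above stretches the (B, S, C) token embeddings to the (B, T, C) frame rate by giving every token roughly the same duration (via `prepare_avg_tokens_durations`) and gathering the matching embedding for each frame. A loop-based sketch of the same idea; the exact rounding policy is an assumption, since `prepare_avg_tokens_durations` lives in `utils.py` and is not shown in this diff:

```python
import torch

def upsample_tokens_to_frames(embed, tokens_lens, features_lens):
    """(B, S, C) token embeddings -> (B, T, C) frame-level text condition,
    giving each token an (approximately) equal share of its utterance's frames."""
    B, _, C = embed.shape
    T = int(features_lens.max())
    out = embed.new_zeros(B, T, C)
    for b in range(B):
        n_tok, n_frame = int(tokens_lens[b]), int(features_lens[b])
        # equal duration per token, rounded up so that every frame is covered
        dur = torch.full((n_tok,), (n_frame + n_tok - 1) // n_tok, dtype=torch.long)
        frame_to_token = torch.repeat_interleave(torch.arange(n_tok), dur)[:n_frame]
        out[b, :n_frame] = embed[b, frame_to_token]
    return out

embed = torch.randn(1, 4, 8)                 # 4 tokens, 8-dim embeddings
cond = upsample_tokens_to_frames(embed, torch.tensor([4]), torch.tensor([10]))
print(cond.shape)                            # torch.Size([1, 10, 8])
```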
- """ - - (text_condition, padding_mask,) = self.forward_text_train( - tokens=tokens, - features_lens=features_lens, - ) - - speech_condition_mask = condition_time_mask( - features_lens=features_lens, - mask_percent=(0.7, 1.0), - max_len=features.size(1), - ) - speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features) - - if condition_drop_ratio > 0.0: - drop_mask = ( - torch.rand(text_condition.size(0), 1, 1).to(text_condition.device) - > condition_drop_ratio - ) - text_condition = text_condition * drop_mask - - xt = features * t + noise * (1 - t) - ut = features - noise # (B, T, F) - - vt = self.forward_fm_decoder( - t=t, - xt=xt, - text_condition=text_condition, - speech_condition=speech_condition, - padding_mask=padding_mask, - ) - - loss_mask = speech_condition_mask & (~padding_mask) - fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2) - - return fm_loss - - def sample( - self, - tokens: List[List[int]], - prompt_tokens: List[List[int]], - prompt_features: torch.Tensor, - prompt_features_lens: torch.Tensor, - features_lens: Optional[torch.Tensor] = None, - speed: float = 1.0, - t_shift: float = 1.0, - duration: str = "predict", - num_step: int = 5, - guidance_scale: float = 0.5, - ) -> torch.Tensor: - """ - Generate acoustic features, given text tokens, prompts feature - and prompt transcription's text tokens. - Args: - tokens: a list of list of text tokens. - prompt_tokens: a list of list of prompt tokens. - prompt_features: the prompt feature with the shape - (batch_size, seq_len, feat_dim). - prompt_features_lens: the length of each prompt feature, - with the shape (batch_size,). - features_lens: the length of the predicted eature, with the - shape (batch_size,). It is used only when duration is "real". - duration: "real" or "predict". If "real", the predicted - feature length is given by features_lens. - num_step: the number of steps to use in the ODE solver. - guidance_scale: the guidance scale for classifier-free guidance. - distill: whether to use the distillation model for sampling. - """ - - assert duration in ["real", "predict"] - - if duration == "predict": - ( - text_condition, - padding_mask, - ) = self.forward_text_inference_ratio_duration( - tokens=tokens, - prompt_tokens=prompt_tokens, - prompt_features_lens=prompt_features_lens, - speed=speed, - ) - else: - assert features_lens is not None - text_condition, padding_mask = self.forward_text_inference_gt_duration( - tokens=tokens, - features_lens=features_lens, - prompt_tokens=prompt_tokens, - prompt_features_lens=prompt_features_lens, - ) - batch_size, num_frames, _ = text_condition.shape - - speech_condition = torch.nn.functional.pad( - prompt_features, (0, 0, 0, num_frames - prompt_features.size(1)) - ) # (B, T, F) - - # False means speech condition positions. 
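The training forward pass above is conditional flow matching: sample a timestep t per utterance, form the linear interpolant x_t = t·x1 + (1 − t)·x0 between clean features x1 and Gaussian noise x0, and regress the decoder output onto the constant target velocity x1 − x0, scoring only frames that are neither prompt nor padding; the text condition is zeroed for a fraction of utterances so classifier-free guidance is possible at inference. A self-contained sketch of that objective (the `velocity_fn` argument stands in for `forward_fm_decoder`; all names are illustrative):

```python
import torch

def flow_matching_loss(velocity_fn, x1, noise, text_cond, speech_cond,
                       padding_mask, cond_mask, drop_ratio=0.2):
    """One conditional flow-matching training step (sketch).
    x1/noise: (B, T, F); padding_mask/cond_mask: (B, T) bool,
    cond_mask is True where features must be generated (not given as prompt)."""
    B = x1.size(0)
    t = torch.rand(B, 1, 1, device=x1.device)        # one timestep per utterance
    xt = t * x1 + (1.0 - t) * noise                  # linear interpolant
    ut = x1 - noise                                  # target velocity
    # classifier-free guidance training: sometimes drop the text condition
    keep = (torch.rand(B, 1, 1, device=x1.device) > drop_ratio).to(x1.dtype)
    vt = velocity_fn(t=t, xt=xt, text_condition=text_cond * keep,
                     speech_condition=speech_cond, padding_mask=padding_mask)
    loss_mask = cond_mask & (~padding_mask)          # generated, non-padded frames only
    return torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
```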
- speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames) - speech_condition = torch.where( - speech_condition_mask.unsqueeze(-1), - torch.zeros_like(speech_condition), - speech_condition, - ) - - x0 = torch.randn( - batch_size, num_frames, self.feat_dim, device=text_condition.device - ) - solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder") - - x1 = solver.sample( - x=x0, - text_condition=text_condition, - speech_condition=speech_condition, - padding_mask=padding_mask, - num_step=num_step, - guidance_scale=guidance_scale, - t_shift=t_shift, - ) - x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens - x1_prompt = torch.zeros( - x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device - ) - x1_wo_prompt = torch.zeros( - x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device - ) - for i in range(x1.size(0)): - x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[ - i, - prompt_features_lens[i] : prompt_features_lens[i] - + x1_wo_prompt_lens[i], - ] - x1_prompt[i, : prompt_features_lens[i], :] = x1[ - i, : prompt_features_lens[i] - ] - - return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens - - def sample_intermediate( - self, - tokens: List[List[int]], - features: torch.Tensor, - features_lens: torch.Tensor, - noise: torch.Tensor, - speech_condition_mask: torch.Tensor, - t_start: torch.Tensor, - t_end: torch.Tensor, - num_step: int = 1, - guidance_scale: torch.Tensor = None, - ) -> torch.Tensor: - """ - Generate acoustic features in intermediate timesteps. - Args: - tokens: List of list of token ids. - features: The acoustic features, with the shape (batch, seq_len, feat_dim). - features_lens: The length of each acoustic feature sequence, - with the shape (batch,). - noise: The initial noise, with the shape (batch, seq_len, feat_dim). - speech_condition_mask: The mask for speech condition, True means - non-condition positions, with the shape (batch, seq_len). - t_start: The start timestep, with the shape (batch, 1, 1). - t_end: The end timestep, with the shape (batch, 1, 1). - num_step: The number of steps for sampling. - guidance_scale: The scale for classifier-free guidance inference, - with the shape (batch, 1, 1). - distill: Whether to use distillation model. 
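`sample()` delegates the actual ODE integration to `EulerSolver` from `solver.py`, which is not part of this diff. Under the usual flow-matching formulation, a bare-bones Euler integrator with classifier-free guidance could look like the sketch below; the guidance combination and the `t_shift` warping are assumptions about typical practice, not a transcription of `solver.py`:

```python
import torch

@torch.no_grad()
def euler_sample(velocity_fn, x, text_cond, speech_cond, padding_mask,
                 num_step=16, guidance_scale=0.5, t_shift=1.0):
    """Integrate dx/dt = v(t, x) from t=0 (noise) to t=1 (features) with Euler steps."""
    t_grid = torch.linspace(0.0, 1.0, num_step + 1)
    # optional warping of the time grid; t_shift=1.0 leaves it uniform
    t_grid = t_shift * t_grid / (1.0 + (t_shift - 1.0) * t_grid)
    for i in range(num_step):
        t, t_next = t_grid[i], t_grid[i + 1]
        v = velocity_fn(t=t, xt=x, text_condition=text_cond,
                        speech_condition=speech_cond, padding_mask=padding_mask)
        if guidance_scale > 0.0:
            # classifier-free guidance: push away from the unconditional prediction
            v_uncond = velocity_fn(t=t, xt=x,
                                   text_condition=torch.zeros_like(text_cond),
                                   speech_condition=speech_cond,
                                   padding_mask=padding_mask)
            v = v + guidance_scale * (v - v_uncond)
        x = x + (t_next - t) * v
    return x
```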
- """ - (text_condition, padding_mask,) = self.forward_text_train( - tokens=tokens, - features_lens=features_lens, - ) - - speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features) - - solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder") - - x_t_end = solver.sample( - x=noise, - text_condition=text_condition, - speech_condition=speech_condition, - padding_mask=padding_mask, - num_step=num_step, - guidance_scale=guidance_scale, - t_start=t_start, - t_end=t_end, - ) - x_t_end_lens = (~padding_mask).sum(-1) - return x_t_end, x_t_end_lens - - -class DistillTTSModelTrainWrapper(TtsModel): - """Wrapper for training the distilled TTS model.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.distill = True - - def forward( - self, - tokens: List[List[int]], - features: torch.Tensor, - features_lens: torch.Tensor, - noise: torch.Tensor, - speech_condition_mask: torch.Tensor, - t_start: torch.Tensor, - t_end: torch.Tensor, - num_step: int = 1, - guidance_scale: torch.Tensor = None, - ) -> torch.Tensor: - - return self.sample_intermediate( - tokens=tokens, - features=features, - features_lens=features_lens, - noise=noise, - speech_condition_mask=speech_condition_mask, - t_start=t_start, - t_end=t_end, - num_step=num_step, - guidance_scale=guidance_scale, - ) diff --git a/egs/zipvoice/zipvoice/optim.py b/egs/zipvoice/zipvoice/optim.py deleted file mode 100644 index daf17556a..000000000 --- a/egs/zipvoice/zipvoice/optim.py +++ /dev/null @@ -1,1256 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import logging -import random -from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union - -import torch -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.optim import Optimizer - - -class BatchedOptimizer(Optimizer): - """ - This class adds to class Optimizer the capability to optimize parameters in batches: - it will stack the parameters and their grads for you so the optimizer can work - on tensors with an extra leading dimension. This is intended for speed with GPUs, - as it reduces the number of kernels launched in the optimizer. - - Args: - params: - """ - - def __init__(self, params, defaults): - super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager - def batched_params(self, param_group, group_params_names): - """ - This function returns (technically, yields) a list of - of tuples (p, state), where - p is a `fake` parameter that is stacked (over axis 0) from real parameters - that share the same shape, and its gradient is also stacked; - `state` is the state corresponding to this batch of parameters - (it will be physically located in the "state" for one of the real - parameters, the last one that has any particular shape and dtype). 
- - This function is decorated as a context manager so that it can - write parameters back to their "real" locations. - - The idea is, instead of doing: - - for p in group["params"]: - state = self.state[p] - ... - - you can do: - - with self.batched_params(group["params"]) as batches: - for p, state, p_names in batches: - ... - - - Args: - group: a parameter group, which is a list of parameters; should be - one of self.param_groups. - group_params_names: name for each parameter in group, - which is List[str]. - """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter - batches_names = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - - assert len(param_group) == len(group_params_names) - for p, named_p in zip(param_group, group_params_names): - key = (str(p.dtype), *p.shape) - batches[key].append(p) - batches_names[key].append(named_p) - - batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted( - range(len(batches_names)), key=lambda i: batches_names_keys[i] - ) - batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] - batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] - - stacked_params_dict = dict() - - # turn batches into a list, in deterministic order. - # tuples will contain tuples of (stacked_param, state, stacked_params_names), - # one for each batch in `batches`. - tuples = [] - - for batch, batch_names in zip(batches, batches_names): - p = batch[0] - # we arbitrarily store the state in the - # state corresponding to the 1st parameter in the - # group. class Optimizer will take care of saving/loading state. - state = self.state[p] - p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) - p_stacked.grad = grad - stacked_params_dict[key] = p_stacked - tuples.append((p_stacked, state, batch_names)) - - yield tuples # <-- calling code will do the actual optimization here! - - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): - for i, p in enumerate(batch): # batch is list of Parameter - p.copy_(stacked_params[i]) - - -def basic_step(group, p, state, grad): - # computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet. - lr = group["lr"] - if p.numel() == p.shape[0]: - lr = lr * group["scalar_lr_scale"] - beta2 = group["betas"][1] - eps = group["eps"] - # p shape: (batch_size,) or (batch_size, 1, [1,..]) - try: - exp_avg_sq = state[ - "exp_avg_sq" - ] # shape: (batch_size,) or (batch_size, 1, [1,..]) - except KeyError: - exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - state["exp_avg_sq"] = exp_avg_sq - - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. - # slower update at the start will help stability anyway. - bias_correction2 = 1 - beta2 ** (state["step"] + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - denom = exp_avg_sq.sqrt().add_(eps) - - return -lr * grad / denom - - -def scaling_step(group, p, state, grad): - delta = basic_step(group, p, state, grad) - if p.numel() == p.shape[0]: - return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.) 
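`batched_params` stacks every parameter that shares a dtype and shape so the whole batch can be updated with a single set of tensor operations. The grouping key it uses is just `(dtype, *shape)`; a stripped-down illustration of that bookkeeping (helper name is illustrative):

```python
from collections import defaultdict

import torch

def group_params_by_shape(named_params):
    """Group parameters sharing dtype and shape, the precondition for stacking
    them into one batched tensor as batched_params does."""
    groups = defaultdict(list)
    for name, p in named_params:
        groups[(str(p.dtype), *p.shape)].append((name, p))
    return groups

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
for key, group in group_params_by_shape(model.named_parameters()).items():
    # e.g. ('torch.float32', 8, 8) -> ['0.weight', '1.weight']
    print(key, [name for name, _ in group])
```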
- - step = state["step"] - size_update_period = group["size_update_period"] - - try: - param_rms = state["param_rms"] - scale_grads = state["scale_grads"] - scale_exp_avg_sq = state["scale_exp_avg_sq"] - except KeyError: - # we know p.ndim > 1 because we'd have returned above if not, so don't worry - # about the speial case of dim=[] that pytorch treats inconsistently. - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - param_rms = param_rms.to(torch.float) - scale_exp_avg_sq = torch.zeros_like(param_rms) - scale_grads = torch.zeros( - size_update_period, *param_rms.shape, dtype=torch.float, device=p.device - ) - state["param_rms"] = param_rms - state["scale_grads"] = scale_grads - state["scale_exp_avg_sq"] = scale_exp_avg_sq - - # on every step, update the gradient w.r.t. the scale of the parameter, we - # store these as a batch and periodically update the size (for speed only, to - # avoid too many operations). - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - - # periodically recompute the value of param_rms. - if step % size_update_period == size_update_period - 1: - param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()) - - param_min_rms = group["param_min_rms"] - - # scale the step size by param_rms. This is the most important "scaling" part of - # ScaledAdam - delta *= param_rms.clamp(min=param_min_rms) - - if step % size_update_period == size_update_period - 1 and step > 0: - # This block updates the size of parameter by adding a step ("delta") value in - # the direction of either shrinking or growing it. - beta2 = group["betas"][1] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - batch_size = p.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - - # The following may help prevent instability: don't allow the scale step to be too large in - # either direction. - scale_step.clamp_(min=-0.1, max=0.1) - - # and ensure the parameter rms after update never exceeds param_max_rms. - # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the - # topology or settings. 
- scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) - - delta.add_(p * scale_step) - - return delta - - -def momentum_step(group, p, state, grad): - delta = scaling_step(group, p, state, grad) - beta1 = group["betas"][0] - try: - stored_delta = state["delta"] - except KeyError: - stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) - state["delta"] = stored_delta - stored_delta.mul_(beta1) - stored_delta.add_(delta, alpha=(1 - beta1)) - # we don't bother doing the "bias correction" part of Adam for beta1 because this is just - # an edge effect that affects the first 10 or so batches; and the effect of not doing it - # is just to do a slower update for the first few batches, which will help stability. - return stored_delta - - -class ScaledAdam(BatchedOptimizer): - """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) - - - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - Unlike common optimizers, which accept model.parameters() or groups of parameters(), - this optimizer could accept model.named_parameters() or groups of named_parameters(). - See comments of function _get_names_of_parameters for its 4 possible cases. - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. 
- clipping_update_period: if clipping_scale is specified, this is the period - """ - - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - ): - - defaults = dict( - lr=lr, - clipping_scale=clipping_scale, - betas=betas, - scalar_lr_scale=scalar_lr_scale, - eps=eps, - param_min_rms=param_min_rms, - param_max_rms=param_max_rms, - scalar_max=scalar_max, - size_update_period=size_update_period, - clipping_update_period=clipping_update_period, - ) - - # If params only contains parameters or group of parameters, - # i.e when parameter names are not given, - # this flag will be set to False in funciton _get_names_of_parameters. - self.show_dominant_parameters = True - param_groups, parameters_names = self._get_names_of_parameters(params) - super(ScaledAdam, self).__init__(param_groups, defaults) - assert len(self.param_groups) == len(parameters_names) - self.parameters_names = parameters_names - - def _get_names_of_parameters( - self, params_or_named_params - ) -> Tuple[List[Dict], List[List[str]]]: - """ - Args: - params_or_named_params: according to the way ScaledAdam is initialized in train.py, - this argument could be one of following 4 cases, - case 1, a generator of parameter, e.g.: - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 2, a list of parameter groups with different config, e.g.: - model_param_groups = [ - {'params': model.encoder.parameters(), 'lr': 0.05}, - {'params': model.decoder.parameters(), 'lr': 0.01}, - {'params': model.joiner.parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0) - - case 3, a generator of named_parameter, e.g.: - optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0) - - case 4, a list of named_parameter groups with different config, e.g.: - model_named_param_groups = [ - {'named_params': model.encoder.named_parameters(), 'lr': 0.05}, - {'named_params': model.decoder.named_parameters(), 'lr': 0.01}, - {'named_params': model.joiner.named_parameters(), 'lr': 0.03}, - ] - optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0) - - For case 1 and case 2, input params is used to initialize the underlying torch.optimizer. - For case 3 and case 4, firstly, names and params are extracted from input named_params, - then, these extracted params are used to initialize the underlying torch.optimizer, - and these extracted names are mainly used by function - `_show_gradient_dominating_parameter` - - Returns: - Returns a tuple containing 2 elements: - - `param_groups` with type List[Dict], each Dict element is a parameter group. - An example of `param_groups` could be: - [ - {'params': `one iterable of Parameter`, 'lr': 0.05}, - {'params': `another iterable of Parameter`, 'lr': 0.08}, - {'params': `a third iterable of Parameter`, 'lr': 0.1}, - ] - - `param_gruops_names` with type List[List[str]], - each `List[str]` is for a group['params'] in param_groups, - and each `str` is the name of a parameter. - A dummy name "foo" is related to each parameter, - if input are params without names, i.e. case 1 or case 2. - """ - # variable naming convention in this function: - # p is short for param. - # np is short for named_param. - # p_or_np is short for param_or_named_param. - # cur is short for current. 
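Pulling the three helpers above together (`basic_step`, `scaling_step`, `momentum_step`), the per-tensor update ScaledAdam applies is roughly "an Adam-normalized gradient, rescaled by the parameter's RMS, smoothed with momentum". A condensed, non-batched sketch under that reading, omitting the learnt size update, gradient clipping, and the scalar-parameter special case; it is illustrative, not the optimizer's actual code path:

```python
import torch

@torch.no_grad()
def scaled_adam_update(p, grad, state, lr=0.045, betas=(0.9, 0.98),
                       eps=1e-8, param_min_rms=1e-5):
    """Condensed per-tensor recap of basic_step + scaling_step + momentum_step."""
    beta1, beta2 = betas
    step = state.setdefault("step", 0)
    exp_avg_sq = state.setdefault("exp_avg_sq", torch.zeros_like(p))
    delta_ema = state.setdefault("delta", torch.zeros_like(p))

    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq / (1 - beta2 ** (step + 1))).sqrt().add_(eps)
    delta = -lr * grad / denom                                       # Adam-normalized step
    delta = delta * (p ** 2).mean().sqrt().clamp(min=param_min_rms)  # scale by parameter RMS
    delta_ema.mul_(beta1).add_(delta, alpha=1 - beta1)               # momentum on the step
    p.add_(delta_ema)
    state["step"] = step + 1
```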
- # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}. - # groups is a List[group] - - iterable_or_groups = list(params_or_named_params) - if len(iterable_or_groups) == 0: - raise ValueError("optimizer got an empty parameter list") - - # The first value of returned tuple. A list of dicts containing at - # least 'params' as a key. - param_groups = [] - - # The second value of returned tuple, - # a List[List[str]], each sub-List is for a group. - param_groups_names = [] - - if not isinstance(iterable_or_groups[0], dict): - # case 1 or case 3, - # the input is an iterable of parameter or named parameter. - param_iterable_cur_group = [] - param_names_cur_group = [] - for p_or_np in iterable_or_groups: - if isinstance(p_or_np, tuple): - # case 3 - name, param = p_or_np - else: - # case 1 - assert isinstance(p_or_np, torch.Tensor) - param = p_or_np - # Assign a dummy name as a placeholder - name = "foo" - self.show_dominant_parameters = False - param_iterable_cur_group.append(param) - param_names_cur_group.append(name) - param_groups.append({"params": param_iterable_cur_group}) - param_groups_names.append(param_names_cur_group) - else: - # case 2 or case 4 - # the input is groups of parameter or named parameter. - for cur_group in iterable_or_groups: - if "named_params" in cur_group: - name_list = [x[0] for x in cur_group["named_params"]] - p_list = [x[1] for x in cur_group["named_params"]] - del cur_group["named_params"] - cur_group["params"] = p_list - else: - assert "params" in cur_group - name_list = ["foo" for _ in cur_group["params"]] - param_groups.append(cur_group) - param_groups_names.append(name_list) - - return param_groups, param_groups_names - - def __setstate__(self, state): - super(ScaledAdam, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - batch = True - - for group, group_params_names in zip(self.param_groups, self.parameters_names): - - with self.batched_params(group["params"], group_params_names) as batches: - - # batches is list of pairs (stacked_param, state). stacked_param is like - # a regular parameter, and will have a .grad, but the 1st dim corresponds to - # a stacking dim, it is not a real dim. - - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized - clipping_scale = 1 - else: - clipping_scale = self._get_clipping_scale(group, batches) - - for p, state, _ in batches: - # Perform optimization step. - # grad is not going to be None, we handled that when creating the batches. - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - - try: - cur_step = state["step"] - except KeyError: - state["step"] = 0 - cur_step = 0 - - grad = ( - p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale) - ) - p += momentum_step(group, p.detach(), state, grad) - - if p.numel() == p.shape[0]: # scalar parameter - scalar_max = group["scalar_max"] - p.clamp_(min=-scalar_max, max=scalar_max) - - state["step"] = cur_step + 1 - - return loss - - def _get_clipping_scale( - self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] - ) -> float: - """ - Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. 
we will scale the gradients - by this amount before applying the rest of the update. - - Args: - group: the parameter group, an item in self.param_groups - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - """ - assert len(tuples) >= 1 - clipping_scale = group["clipping_scale"] - (first_p, first_state, _) = tuples[0] - step = first_state["step"] - if clipping_scale is None or step == 0: - # no clipping. return early on step == 0 because the other - # parameters' state won't have been initialized yet. - return 1.0 - clipping_update_period = group["clipping_update_period"] - scalar_lr_scale = group["scalar_lr_scale"] - - tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() * ( - scalar_lr_scale**2 - ) # sum() to change shape [1] to [] - else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() - - tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) - first_state["model_norms"][step % clipping_update_period] = tot_norm - - irregular_estimate_steps = [ - i for i in [10, 20, 40] if i < clipping_update_period - ] - if step % clipping_update_period == 0 or step in irregular_estimate_steps: - # Print some stats. - # We don't reach here if step == 0 because we would have returned - # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") - if step in irregular_estimate_steps: - sorted_norms = sorted_norms[-step:] - num_norms = sorted_norms.numel() - quartiles = [] - for n in range(0, 5): - index = min(num_norms - 1, (num_norms // 4) * n) - quartiles.append(sorted_norms[index].item()) - - median = quartiles[2] - if median - median != 0: - raise RuntimeError("Too many grads were not finite") - threshold = clipping_scale * median - if step in irregular_estimate_steps: - # use larger thresholds on first few steps of estimating threshold, - # as norm may be changing rapidly. - threshold = threshold * 2.0 - first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / num_norms - if "num_clipped" in first_state - else 0.0 - ) - first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.warning( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) - - try: - model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - return 1.0 # threshold has not yet been set. - - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) - if ans != ans: # e.g. 
ans is nan - ans = 0.0 - if ans < 1.0: - first_state["num_clipped"] += 1 - if ans < 0.5: - logging.warning( - f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" - ) - if self.show_dominant_parameters: - assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter( - tuples, tot_sumsq, group["scalar_lr_scale"] - ) - self._show_param_with_unusual_grad(tuples) - - if ans == 0.0: - for (p, state, param_names) in tuples: - p.grad.zero_() # get rid of infinity() - - return ans - - def _show_param_with_unusual_grad( - self, - tuples: List[Tuple[Tensor, dict, List[str]]], - ): - """ - Print information about parameter which has the largest ratio of grad-on-this-batch - divided by normal grad size. - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - """ - largest_ratio = 0.0 - largest_name = "" - # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor) - ratios_names = [] - for (p, state, batch_param_names) in tuples: - dims = list(range(1, p.ndim)) - - def mean(x): - # workaround for bad interface of torch's "mean" for when dims is the empty list. - if len(dims) > 0: - return x.mean(dim=dims) - else: - return x - - grad_ratio = ( - (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims)) - .sqrt() - .to("cpu") - ) - - ratios_names += zip( - grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0) - ) - - ratios_names = sorted(ratios_names, reverse=True) - ratios_names = ratios_names[:10] - ratios_names = [ - (ratio, name, largest_index(tensor)) - for (ratio, name, tensor) in ratios_names - ] - - logging.debug( - f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}" - ) - - def _show_gradient_dominating_parameter( - self, - tuples: List[Tuple[Tensor, dict, List[str]]], - tot_sumsq: Tensor, - scalar_lr_scale: float, - ): - """ - Show information of parameter which dominates tot_sumsq. - - Args: - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - tot_sumsq: sumsq of all parameters. Though it's could be calculated - from tuples, we still pass it to save some time. - """ - all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: - # p is a stacked batch parameters. - batch_grad = p.grad - if p.numel() == p.shape[0]: # a batch of scalars - # Dummy values used by following `zip` statement. - batch_rms_orig = torch.full( - p.shape, scalar_lr_scale, device=batch_grad.device - ) - else: - batch_rms_orig = state["param_rms"] - batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2 - if batch_grad.ndim > 1: - # need to guard it with if-statement because sum() sums over - # all dims if dim == (). 
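`_get_clipping_scale` above keeps a rolling buffer of recent (RMS-normalized) gradient norms and clips to `clipping_scale` times their median, re-estimating the threshold every `clipping_update_period` steps. The mechanism boiled down to plain Python, without the RMS normalization or the extra early-step estimates; class and method names are illustrative:

```python
import torch

class MedianGradClipper:
    """Clip to `clipping_scale` times the median of the last `period` grad norms."""

    def __init__(self, clipping_scale=2.0, period=100):
        self.clipping_scale = clipping_scale
        self.period = period
        self.norms = []
        self.threshold = None

    def scale_for(self, grad_norm: float) -> float:
        self.norms.append(grad_norm)
        if len(self.norms) >= self.period:
            self.threshold = self.clipping_scale * float(
                torch.tensor(self.norms).median())
            self.norms = []
        if self.threshold is None:
            return 1.0                       # not enough history yet: no clipping
        return min(1.0, self.threshold / (grad_norm + 1.0e-20))

# usage: multiply all gradients by clipper.scale_for(total_grad_norm) before stepping
```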
- batch_sumsq_orig = batch_sumsq_orig.sum( - dim=list(range(1, batch_grad.ndim)) - ) - for name, sumsq_orig, rms, grad in zip( - batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad - ): - - proportion_orig = sumsq_orig / tot_sumsq - all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - - sorted_by_proportion = { - k: v - for k, v in sorted( - all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True - ) - } - dominant_param_name = next(iter(sorted_by_proportion)) - ( - dominant_proportion, - dominant_sumsq, - dominant_rms, - dominant_grad, - ) = sorted_by_proportion[dominant_param_name] - logging.debug( - f"Parameter dominating tot_sumsq {dominant_param_name}" - f" with proportion {dominant_proportion:.2f}," - f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" - f"={dominant_sumsq:.3e}," - f" grad_sumsq={(dominant_grad**2).sum():.3e}," - f" orig_rms_sq={(dominant_rms**2).item():.3e}" - ) - - -def largest_index(x: Tensor): - x = x.contiguous() - argmax = x.abs().argmax().item() - return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)] - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("base_lr", group["lr"]) - - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - # the user might try to override the base_lr, so don't include this in the state. - # previously they were included. - # "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # the things with base_lrs are a work-around for a previous problem - # where base_lrs were written with the state dict. - base_lrs = self.base_lrs - self.__dict__.update(state_dict) - self.base_lrs = base_lrs - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. 
If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - logging.warning( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - The basic formula (before warmup) is: - lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - If you don't have the concept of epochs, or one epoch takes a very long time, - you can replace the notion of 'epoch' with some measure of the amount of data - processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to - some measure representing "quite a lot of data": say, one fifth or one third - of an entire training run, but it doesn't matter much. You could also use - Eden2 which has only the notion of batches. - - We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - warmup_start: float = 0.5, - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - self.warmup_batches = warmup_batches - - assert 0.0 <= warmup_start <= 1.0, warmup_start - self.warmup_start = warmup_start - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else self.warmup_start - + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) - # else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - -class Eden2(LRScheduler): - """ - Eden2 scheduler, simpler than Eden because it does not use the notion of epoch, - only batches. - - The basic formula (before warmup) is: - lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup - - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - - E.g. 
suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - warmup_start: float = 0.5, - verbose: bool = False, - ): - super().__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.warmup_batches = warmup_batches - - assert 0.0 <= warmup_start <= 1.0, warmup_start - self.warmup_start = warmup_start - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.5 - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else self.warmup_start - + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches) - # else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - -class FixedLRScheduler(LRScheduler): - """ - Fixed learning rate scheduler. - - Args: - optimizer: the optimizer to change the learning rates on - """ - - def __init__( - self, - optimizer: Optimizer, - verbose: bool = False, - ): - super(FixedLRScheduler, self).__init__(optimizer, verbose) - - def get_lr(self): - - return [x for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = ScaledAdam(m.parameters(), lr=0.03) - - scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - - logging.info(f"last lr = {scheduler.get_last_lr()}") - logging.info(f"state dict = {scheduler.state_dict()}") - - -# This is included mostly as a baseline for ScaledAdam. -class Eve(Optimizer): - """ - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. 
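As a quick sanity check of the Eden formula defined above, here is its decay factor evaluated at a few points (warmup assumed finished, with base_lr = 0.04, lr_batches = 5000, lr_epochs = 6, the values suggested in the docstring):

```python
def eden_lr(batch: float, epoch: float,
            base_lr: float = 0.04, lr_batches: float = 5000.0,
            lr_epochs: float = 6.0) -> float:
    batch_factor = ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
    epoch_factor = ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
    return base_lr * batch_factor * epoch_factor

print(eden_lr(batch=0, epoch=0))        # 0.040: no decay at the start
print(eden_lr(batch=5000, epoch=6))     # ~0.028: each term contributes 2**-0.25
print(eden_lr(batch=50000, epoch=20))   # ~0.0068: late-training learning rate
```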
_Adam: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). 
- is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) - p.mul_(1 - (weight_decay * is_above_target_rms)) - - p.addcdiv_(exp_avg, denom, value=-step_size) - - if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) - - return loss - - -def _test_scaled_adam(hidden_dim: int): - import timeit - - from scaling import ScaledLinear - - E = 100 - B = 4 - T = 2 - logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") - dtype = torch.float32 - - fix_random_seed(42) - # these input_magnitudes and output_magnitudes are to test that - # Abel is working as we expect and is able to adjust scales of - # different dims differently. - input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - - for iter in [1, 0]: - fix_random_seed(42) - Linear = torch.nn.Linear if iter == 0 else ScaledLinear - - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) - - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] - - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0) - scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - - start = timeit.default_timer() - avg_loss = 0.0 - for epoch in range(180): - scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: - # optim.reset_speedup() # check it doesn't crash. 
- - # if epoch == 130: - # opts = diagnostics.TensorDiagnosticOptions( - # 512 - # ) # allow 4 megabytes per sub-module - # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x, y) in enumerate(train_pairs): - y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 - if epoch == 0 and n == 0: - avg_loss = loss.item() - else: - avg_loss = 0.98 * avg_loss + 0.02 * loss.item() - if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) - lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} - loss.log().backward() - optim.step() - optim.zero_grad() - scheduler.step_batch() - - # diagnostic.print_diagnostics() - - stop = timeit.default_timer() - logging.info(f"Iter={iter}, Time taken: {stop - start}") - - logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) - logging.info(f"input_magnitudes = {input_magnitudes}") - logging.info(f"output_magnitudes = {output_magnitudes}") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - logging.getLogger().setLevel(logging.INFO) - import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) - logging.info(s) - import sys - - if len(sys.argv) > 1: - hidden_dim = int(sys.argv[1]) - else: - hidden_dim = 200 - - _test_scaled_adam(hidden_dim) - _test_eden() diff --git a/egs/zipvoice/zipvoice/scaling.py b/egs/zipvoice/zipvoice/scaling.py deleted file mode 100644 index afe9ad468..000000000 --- a/egs/zipvoice/zipvoice/scaling.py +++ /dev/null @@ -1,1930 +0,0 @@ -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -import math -import random -import sys -from typing import Optional, Tuple, Union - -try: - import k2 -except Exception as ex: - logging.warning( - "k2 is not installed correctly. Swoosh functions will fallback to " - "pytorch implementation." 
- ) - -import torch -import torch.nn as nn -from torch import Tensor - -custom_bwd = lambda func: torch.amp.custom_bwd(func, device_type="cuda") -custom_fwd = lambda func: torch.amp.custom_fwd(func, device_type="cuda") - - -def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: - max_value = torch.max(x, y) - diff = torch.abs(x - y) - return max_value + torch.log1p(torch.exp(-diff)) - - -# RuntimeError: Exporting the operator logaddexp to ONNX opset version -# 14 is not supported. Please feel free to request support or submit -# a pull request on PyTorch GitHub. -# -# The following function is to solve the above error when exporting -# models to ONNX via torch.jit.trace() -def logaddexp(x: Tensor, y: Tensor) -> Tensor: - # Caution(fangjun): Put torch.jit.is_scripting() before - # torch.onnx.is_in_onnx_export(); - # otherwise, it will cause errors for torch.jit.script(). - # - # torch.logaddexp() works for both torch.jit.script() and - # torch.jit.trace() but it causes errors for ONNX export. - # - if torch.jit.is_scripting(): - # Note: We cannot use torch.jit.is_tracing() here as it also - # matches torch.onnx.export(). - return torch.logaddexp(x, y) - elif torch.onnx.is_in_onnx_export(): - return logaddexp_onnx(x, y) - else: - # for torch.jit.trace() - return torch.logaddexp(x, y) - - -class PiecewiseLinear(object): - """ - Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with - the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] - respectively. - """ - - def __init__(self, *args): - assert len(args) >= 1, len(args) - if len(args) == 1 and isinstance(args[0], PiecewiseLinear): - self.pairs = list(args[0].pairs) - else: - self.pairs = [(float(x), float(y)) for x, y in args] - for x, y in self.pairs: - assert isinstance(x, (float, int)), type(x) - assert isinstance(y, (float, int)), type(y) - - for i in range(len(self.pairs) - 1): - assert self.pairs[i + 1][0] > self.pairs[i][0], ( - i, - self.pairs[i], - self.pairs[i + 1], - ) - - def __str__(self): - # e.g. 
'PiecewiseLinear((0., 10.), (100., 0.))' - return f"PiecewiseLinear({str(self.pairs)[1:-1]})" - - def __call__(self, x): - if x <= self.pairs[0][0]: - return self.pairs[0][1] - elif x >= self.pairs[-1][0]: - return self.pairs[-1][1] - else: - cur_x, cur_y = self.pairs[0] - for i in range(1, len(self.pairs)): - next_x, next_y = self.pairs[i] - if x >= cur_x and x <= next_x: - return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) - cur_x, cur_y = next_x, next_y - assert False - - def __mul__(self, alpha): - return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) - - def __add__(self, x): - if isinstance(x, (float, int)): - return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) - s, x = self.get_common_basis(x) - return PiecewiseLinear( - *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def max(self, x): - if isinstance(x, (float, int)): - x = PiecewiseLinear((0, x)) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def min(self, x): - if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear((0, x)) - s, x = self.get_common_basis(x, include_crossings=True) - return PiecewiseLinear( - *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] - ) - - def __eq__(self, other): - return self.pairs == other.pairs - - def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): - """ - Returns (self_mod, p_mod) which are equivalent piecewise linear - functions to self and p, but with the same x values. - - p: the other piecewise linear function - include_crossings: if true, include in the x values positions - where the functions indicate by this and p crosss. - """ - assert isinstance(p, PiecewiseLinear), type(p) - - # get sorted x-values without repetition. - x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] - - if include_crossings: - extra_x_vals = [] - for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): - # if the two lines in this subsegment potentially cross each other.. - diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) - # `pos`, between 0 and 1, gives the relative x position, - # with 0 being x_vals[i] and 1 being x_vals[i+1]. - pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) - extra_x_vals.append(extra_x_val) - if len(extra_x_vals) > 0: - x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] - return ( - PiecewiseLinear(*zip(x_vals, y_vals1)), - PiecewiseLinear(*zip(x_vals, y_vals2)), - ) - - -class ScheduledFloat(torch.nn.Module): - """ - This object is a torch.nn.Module only because we want it to show up in [top_level module].modules(); - it does not have a working forward() function. You are supposed to cast it to float, as - in, float(parent_module.whatever), and use it as something like a dropout prob. - - It is a floating point value whose value changes depending on the batch count of the - training loop. It is a piecewise linear function where you specify the (x,y) pairs - in sorted order on x; x corresponds to the batch index. For batch-index values before the - first x or after the last x, we just use the first or last y value. 
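    For instance (an illustrative schedule, not one prescribed anywhere in this
    file): ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) evaluates to 0.3 at batch 0,
    to 0.2 at batch 10000 (by linear interpolation), and stays at 0.1 for any
    batch count >= 20000.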
- - Example: - self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0) - - `default` is used when self.batch_count is not set or not in training mode or in - torch.jit scripting mode. - """ - - def __init__(self, *args, default: float = 0.0): - super().__init__() - # self.batch_count and self.name will be written to in the training loop. - self.batch_count = None - self.name = None - self.default = default - self.schedule = PiecewiseLinear(*args) - - def extra_repr(self) -> str: - return ( - f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" - ) - - def __float__(self): - batch_count = self.batch_count - if ( - batch_count is None - or not self.training - or torch.jit.is_scripting() - or torch.jit.is_tracing() - ): - return float(self.default) - else: - ans = self.schedule(self.batch_count) - if random.random() < 0.0002: - logging.debug( - f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" - ) - return ans - - def __add__(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, default=self.default) - else: - return ScheduledFloat( - self.schedule + x.schedule, default=self.default + x.default - ) - - def max(self, x): - if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), default=self.default) - else: - return ScheduledFloat( - self.schedule.max(x.schedule), default=max(self.default, x.default) - ) - - -FloatLike = Union[float, ScheduledFloat] - - -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = x_abs < min_abs - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class CutoffEstimator: - """ - Estimates cutoffs of an arbitrary numerical quantity such that a specified - proportion of items will be above the cutoff on average. - - p is the proportion of items that should be above the cutoff. - """ - - def __init__(self, p: float): - self.p = p - # total count of items - self.count = 0 - # total count of items that were above the cutoff - self.count_above = 0 - # initial cutoff value - self.cutoff = 0 - - def __call__(self, x: float) -> bool: - """ - Returns true if x is above the cutoff. - """ - ans = x > self.cutoff - self.count += 1 - if ans: - self.count_above += 1 - cur_p = self.count_above / self.count - delta_p = cur_p - self.p - if (delta_p > 0) == ans: - q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1 - q) - return ans - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. 
- if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors - with torch.amp.autocast("cuda", enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing(): - return x.softmax(dim=dim) - - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class BiasNormFunction(torch.autograd.Function): - # This computes: - # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() - # return x * scales - # (after unsqueezing the bias), but it does it in a memory-efficient way so that - # it can just store the returned value (chances are, this will also be needed for - # some other reason, related to the next operation, so we can save memory). - @staticmethod - def forward( - ctx, - x: Tensor, - bias: Tensor, - log_scale: Tensor, - channel_dim: int, - store_output_for_backprop: bool, - ) -> Tensor: - assert bias.ndim == 1 - if channel_dim < 0: - channel_dim = channel_dim + x.ndim - ctx.store_output_for_backprop = store_output_for_backprop - ctx.channel_dim = channel_dim - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = ( - torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 - ) * log_scale.exp() - ans = x * scales - ctx.save_for_backward( - ans.detach() if store_output_for_backprop else x, - scales.detach(), - bias.detach(), - log_scale.detach(), - ) - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tensor: - ans_or_x, scales, bias, log_scale = ctx.saved_tensors - if ctx.store_output_for_backprop: - x = ans_or_x / scales - else: - x = ans_or_x - x = x.detach() - x.requires_grad = True - bias.requires_grad = True - log_scale.requires_grad = True - with torch.enable_grad(): - # recompute scales from x, bias and log_scale. 
- scales = ( - torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 - ) * log_scale.exp() - ans = x * scales - ans.backward(gradient=ans_grad) - return x.grad, bias.grad.flatten(), log_scale.grad, None, None - - -class BiasNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - Instead, we give the BiasNorm a trainable bias that it can use when - computing the scale for normalization. We also give it a (scalar) - trainable scale on the output. - - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interpreted as an offset from the input's ndim if negative. - This is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - log_scale: the initial log-scale that we multiply the output by; this - is learnable. - log_scale_min: FloatLike, minimum allowed value of log_scale - log_scale_max: FloatLike, maximum allowed value of log_scale - store_output_for_backprop: only possibly affects memory use; recommend - to set to True if you think the output of this module is more likely - than the input of this module to be required to be stored for the - backprop. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False, - ) -> None: - super(BiasNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.log_scale = nn.Parameter(torch.tensor(log_scale)) - self.bias = nn.Parameter(torch.zeros(num_channels)) - - self.log_scale_min = log_scale_min - self.log_scale_max = log_scale_max - - self.store_output_for_backprop = store_output_for_backprop - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - channel_dim = self.channel_dim - if channel_dim < 0: - channel_dim += x.ndim - bias = self.bias - for _ in range(channel_dim + 1, x.ndim): - bias = bias.unsqueeze(-1) - scales = ( - torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 - ) * self.log_scale.exp() - return x * scales - - log_scale = limit_param_value( - self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training, - ) - - return BiasNormFunction.apply( - x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop - ) - - -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). 
- Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: - """ - Behaves like a constructor of a modified version of nn.Conv2d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False, but: - NO PADDING-RELATED ARGS. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv2d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -class ChunkCausalDepthwiseConv1d(torch.nn.Module): - """ - Behaves like a depthwise 1d convolution, except that it is causal in - a chunkwise way, as if we had a block-triangular attention mask. - The chunk size is provided at test time (it should probably be - kept in sync with the attention mask). - - This has a little more than twice the parameters of a conventional - depthwise conv1d module: we implement it by having one - depthwise convolution, of half the width, that is causal (via - right-padding); and one depthwise convolution that is applied only - within chunks, that we multiply by a scaling factor which depends - on the position within the chunk. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - - def __init__( - self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True, - ): - super().__init__() - assert kernel_size % 2 == 1 - - half_kernel_size = (kernel_size + 1) // 2 - # will pad manually, on one side. 
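        # Illustrative numbers (assuming, say, kernel_size = 31, which is not a
        # value mandated anywhere in this file): half_kernel_size = (31 + 1) // 2 = 16,
        # and forward() left-pads by kernel_size // 2 = 15 frames before applying
        # this causal conv, so its output length matches the input length
        # ((15 + seq_len) - 16 + 1 = seq_len) while each output frame only sees
        # the current and earlier frames.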
- self.causal_conv = nn.Conv1d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True, - ) - - self.chunkwise_conv = nn.Conv1d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias, - ) - - # first row is correction factors added to the scale near the left edge of the chunk, - # second row is correction factors added to the scale near the right edge of the chunk, - # both of these are added to a default scale of 1.0. - self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) - self.kernel_size = kernel_size - - with torch.no_grad(): - self.causal_conv.weight[:] *= initial_scale - self.chunkwise_conv.weight[:] *= initial_scale - if bias: - torch.nn.init.uniform_( - self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale - ) - - def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: - """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. - """ - (batch_size, num_channels, seq_len) = x.shape - - # half_kernel_size = self.kernel_size + 1 // 2 - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - if chunk_size < 0 or chunk_size > seq_len: - chunk_size = seq_len - right_pad = -seq_len % chunk_size - - x = torch.nn.functional.pad(x, (left_pad, right_pad)) - - x_causal = self.causal_conv(x[..., : left_pad + seq_len]) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - num_chunks = x_chunk.shape[2] // chunk_size - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( - batch_size * num_chunks, num_channels, chunk_size - ) - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size) - - x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape( - batch_size, num_chunks, num_channels, chunk_size - ).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ - ..., :seq_len - ] - - return x_chunk + x_causal - - def _get_chunk_scale(self, chunk_size: int): - """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" - left_edge = self.chunkwise_conv_scale[0] - right_edge = self.chunkwise_conv_scale[1] - if chunk_size < self.kernel_size: - left_edge = left_edge[:, :chunk_size] - right_edge = right_edge[:, -chunk_size:] - else: - t = chunk_size - self.kernel_size - channels = left_edge.shape[0] - pad = torch.zeros( - channels, t, device=left_edge.device, dtype=left_edge.dtype - ) - left_edge = torch.cat((left_edge, pad), dim=-1) - right_edge = torch.cat((pad, right_edge), dim=-1) - return 1.0 + (left_edge + right_edge) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Streaming Forward function. - - Args: - x: a Tensor of shape (batch_size, channels, seq_len) - cache: cached left context of shape (batch_size, channels, left_pad) - """ - (batch_size, num_channels, seq_len) = x.shape - - # left_pad is half_kernel_size - 1 where half_kernel_size is the size used - # in the causal conv. 
It's the amount by which we must pad on the left, - # to make the convolution causal. - left_pad = self.kernel_size // 2 - - # Pad cache - assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[..., -left_pad:] - - x_causal = self.causal_conv(x) - assert x_causal.shape == (batch_size, num_channels, seq_len) - - x_chunk = x[..., left_pad:] - x_chunk = self.chunkwise_conv(x_chunk) # does not change shape - - chunk_scale = self._get_chunk_scale(chunk_size=seq_len) - x_chunk = x_chunk * chunk_scale - - return x_chunk + x_causal, cache - - -class BalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - ctx.save_for_backward(x) - ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: - (x,) = ctx.saved_tensors - (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config - - try: - with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): - x = x.to(torch.float32) - x = x.detach() - x.requires_grad = True - mean_dims = [i for i in range(x.ndim) if i != channel_dim] - uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) - mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() - rms = uncentered_var.clamp(min=1.0e-20).sqrt() - - m = mean / stddev - # part of loss that relates to mean / stddev - m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs() - - # put a much larger scale on the RMS-max-limit loss, so that if both it and the - # m_loss are violated we fix the RMS loss first. - rms_clamped = rms.clamp(min=min_rms, max=max_rms) - r_loss = (rms_clamped / rms).log().abs() - - loss = m_loss + r_loss - - loss.backward(gradient=torch.ones_like(loss)) - loss_grad = x.grad - loss_grad_rms = ( - (loss_grad**2) - .mean(dim=mean_dims, keepdim=True) - .sqrt() - .clamp(min=1.0e-20) - ) - - loss_grad = loss_grad * (grad_scale / loss_grad_rms) - - x_grad_float = x_grad.to(torch.float32) - # scale each element of loss_grad by the absolute value of the corresponding - # element of x_grad, which we view as a noisy estimate of its magnitude for that - # (frame and dimension). later we can consider factored versions. - x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) - x_grad = x_grad_mod.to(x_grad.dtype) - except Exception as e: - logging.info( - f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." - ) - - return x_grad, None, None, None, None, None, None - - -class Balancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. 
- min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, - ): - super().__init__() - - if prob is None: - prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4) - self.prob = prob - # 5% of the time we will return and do nothing because memory usage is - # too high. - self.mem_cutoff = CutoffEstimator(0.05) - - # actually self.num_channels is no longer needed except for an assertion. - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.min_abs = min_abs - self.max_abs = max_abs - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or not x.requires_grad - or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) - ): - return _no_op(x) - - prob = float(self.prob) - if random.random() < prob: - # The following inner-functions convert from the way we historically specified - # these limitations, as limits on the absolute value and the proportion of positive - # values, to limits on the RMS value and the (mean / stddev). - def _abs_to_rms(x): - # for normally distributed data, if the expected absolute value is x, the - # expected rms value will be sqrt(pi/2) * x. - return 1.25331413732 * x - - def _proportion_positive_to_mean(x): - def _atanh(x): - eps = 1.0e-10 - # eps is to prevent crashes if x is exactly 0 or 1. - # we'll just end up returning a fairly large value. - return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 - - def _approx_inverse_erf(x): - # 1 / (sqrt(pi) * ln(2)), - # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions - # this approximation is extremely crude and gets progressively worse for - # x very close to -1 or +1, but we mostly care about the "middle" region - # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772, - # and math.erf(0.0407316414078772) = 0.045935330944660666, - # which is pretty close to 0.05. 
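                # (Quick arithmetic check of that constant: 1 / (sqrt(pi) * ln(2))
                # = 1 / (1.7724539 * 0.6931472) = 1 / 1.2285745 ~= 0.8139535,
                # matching the value returned below.)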
- return 0.8139535143 * _atanh(x) - - # first convert x from the range 0..1 to the range -1..1 which the error - # function returns - x = -1 + (2 * x) - return _approx_inverse_erf(x) - - min_mean = _proportion_positive_to_mean(float(self.min_positive)) - max_mean = _proportion_positive_to_mean(float(self.max_positive)) - min_rms = _abs_to_rms(float(self.min_abs)) - max_rms = _abs_to_rms(float(self.max_abs)) - grad_scale = float(self.grad_scale) - - assert x.shape[self.channel_dim] == self.num_channels - - return BalancerFunction.apply( - x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt( - x: Tensor, limit: float, penalty: float, name: str = None -) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - - The name is for randomly printed debug info. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss, name) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. 
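    # In symbols: with per-group covariance C, the metric computed below is
    #   mean(diag(C @ C)) / (mean(diag(C)))**2   (up to the 1.0e-20 guard).
    # A quick sanity check of the "perfectly white" claim: if every group's
    # covariance were C = lambda * I, then mean(diag(C)) = lambda and
    # mean(diag(C @ C)) = lambda**2, so the metric equals 1.0; any spread in
    # the eigenvalues (within or across groups) pushes it above 1.0.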
- x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: - ctx.save_for_backward(x) - ctx.module = module - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors - w = ctx.module - - try: - with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, w.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.debug( - f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" - ) - - if metric < float(w.whitening_limit): - w.prob = w.min_prob - return x_grad, None - else: - w.prob = w.max_prob - metric.backward() - penalty_grad = x_detached.grad - scale = w.grad_scale * ( - x_grad.to(torch.float32).norm() - / (penalty_grad.norm() + 1.0e-20) - ) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None - except Exception as e: - logging.info( - f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." - ) - return x_grad, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float, float]], - grad_scale: FloatLike, - ): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 
0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert float(whitening_limit) >= 1 - assert grad_scale >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - self.grad_scale = grad_scale - - if isinstance(prob, float): - prob = (prob, prob) - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob <= self.max_prob <= 1 - self.prob = self.max_prob - self.name = None # will be set in training loop - - def forward(self, x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - grad_scale = float(self.grad_scale) - if not x.requires_grad or random.random() > self.prob or grad_scale == 0: - return _no_op(x) - else: - return WhiteningPenaltyFunction.apply(x, self) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor, name: str): - ctx.y_shape = y.shape - if random.random() < 0.002 and name is not None: - loss_sum = y.sum().item() - logging.debug(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ( - ans_grad, - torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), - None, - ) - - -def with_loss(x, y, name): - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y, name) - - -class ScaleGradFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, alpha: float) -> Tensor: - ctx.alpha = alpha - return x - - @staticmethod - def backward(ctx, grad: Tensor): - return grad * ctx.alpha, None - - -def scale_grad(x: Tensor, alpha: float): - return ScaleGradFunction.apply(x, alpha) - - -class ScaleGrad(nn.Module): - def __init__(self, alpha: float): - super().__init__() - self.alpha = alpha - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return x - return scale_grad(x, self.alpha) - - -class LimitParamValue(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, min: float, max: float): - ctx.save_for_backward(x) - assert max >= min - ctx.min = min - ctx.max = max - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x,) = ctx.saved_tensors - # where x < ctx.min, ensure all grads are negative (this will tend to make - # x more positive). - x_grad = x_grad * torch.where( - torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 - ) - # where x > ctx.max, ensure all grads are positive (this will tend to make - # x more negative). - x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) - return x_grad, None, None - - -def limit_param_value( - x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True -): - # You apply this to (typically) an nn.Parameter during training to ensure that its - # (elements mostly) stays within a supplied range. This is done by modifying the - # gradients in backprop. 
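    # (For example, BiasNorm.forward above applies this to its log_scale
    # parameter with min=-1.5 and max=1.5, the class defaults, so the learned
    # output scale is softly kept within roughly exp(+-1.5) of 1.0.)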
- # It's not necessary to do this on every batch: do it only some of the time, - # to save a little time. - if training and random.random() < prob: - return LimitParamValue.apply(x, min, max) - else: - return x - - -def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = y * (1 - s) + s - - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.044 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.044 - ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. -class Dropout2(nn.Module): - def __init__(self, p: FloatLike): - super().__init__() - self.p = p - - def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) - - -class MulForDropout3(torch.autograd.Function): - # returns (x * y * alpha) where alpha is a float and y doesn't require - # grad and is zero-or-one. 
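    # The memory saving comes from storing only `ans` for backward(): the kept
    # positions are re-derived there as (ans != 0) instead of keeping x or the
    # mask itself around.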
- @staticmethod - @custom_fwd - def forward(ctx, x, y, alpha): - assert not y.requires_grad - ans = x * y * alpha - ctx.save_for_backward(ans) - ctx.alpha = alpha - return ans - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad): - (ans,) = ctx.saved_tensors - x_grad = ctx.alpha * ans_grad * (ans != 0) - return x_grad, None, None - - -# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, -# and it lets you choose one dimension to share the dropout mask over -class Dropout3(nn.Module): - def __init__(self, p: FloatLike, shared_dim: int): - super().__init__() - self.p = p - self.shared_dim = shared_dim - - def forward(self, x: Tensor) -> Tensor: - p = float(self.p) - if not self.training or p == 0: - return _no_op(x) - scale = 1.0 / (1 - p) - rand_shape = list(x.shape) - rand_shape[self.shared_dim] = 1 - mask = torch.rand(*rand_shape, device=x.device) > p - ans = MulForDropout3.apply(x, mask, scale) - return ans - - -class SwooshLFunction(torch.autograd.Function): - """ - swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - coeff = -0.08 - - with torch.amp.autocast("cuda", enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 - - if not requires_grad: - return y - - y.backward(gradient=torch.ones_like(y)) - - grad = x.grad - floor = coeff - ceil = 1.0 + coeff + 0.005 - - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - - coeff = -0.08 - floor = coeff - ceil = 1.0 + coeff + 0.005 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class SwooshL(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation.""" - if ( - torch.jit.is_scripting() - or torch.jit.is_tracing() - or "k2" not in sys.modules - ): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 - if not x.requires_grad: - return k2.swoosh_l_forward(x) - else: - return k2.swoosh_l(x) - # return SwooshLFunction.apply(x) - - -class SwooshLOnnx(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation.""" - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 - - -class SwooshRFunction(torch.autograd.Function): - """ - swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - - derivatives are between -0.08 and 0.92. 
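    The constant 0.313261687 is log(1 + exp(-1)), so swoosh_r(0) = 0; the
    derivative is sigmoid(x - 1) - 0.08, which is where the (-0.08, 0.92)
    bounds above come from.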
- """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - - if x.dtype == torch.float16: - x = x.to(torch.float32) - - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - - with torch.amp.autocast("cuda", enabled=False): - with torch.enable_grad(): - x = x.detach() - x.requires_grad = True - y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 - - if not requires_grad: - return y - y.backward(gradient=torch.ones_like(y)) - - grad = x.grad - floor = -0.08 - ceil = 0.925 - - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.08 - ceil = 0.925 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class SwooshR(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation.""" - if ( - torch.jit.is_scripting() - or torch.jit.is_tracing() - or "k2" not in sys.modules - ): - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 - if not x.requires_grad: - return k2.swoosh_r_forward(x) - else: - return k2.swoosh_r(x) - # return SwooshRFunction.apply(x) - - -class SwooshROnnx(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation.""" - zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687 - - -# simple version of SwooshL that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshLForward(x: Tensor): - x_offset = x - 4.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return log_sum - 0.08 * x - 0.035 - - -# simple version of SwooshR that does not redefine the backprop, used in -# ActivationDropoutAndLinearFunction. -def SwooshRForward(x: Tensor): - x_offset = x - 1.0 - log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) - return log_sum - 0.08 * x - 0.313261687 - - -class ActivationDropoutAndLinearFunction(torch.autograd.Function): - @staticmethod - @custom_fwd - def forward( - ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int], - ): - if dropout_p != 0.0: - dropout_shape = list(x.shape) - if dropout_shared_dim is not None: - dropout_shape[dropout_shared_dim] = 1 - # else it won't be very memory efficient. - dropout_mask = (1.0 / (1.0 - dropout_p)) * ( - torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p - ) - else: - dropout_mask = None - - ctx.save_for_backward(x, weight, bias, dropout_mask) - - ctx.activation = activation - - forward_activation_dict = { - "SwooshL": k2.swoosh_l_forward, - "SwooshR": k2.swoosh_r_forward, - } - # it will raise a KeyError if this fails. This will be an error. We let it - # propagate to the user. 
- activation_func = forward_activation_dict[activation] - x = activation_func(x) - if dropout_mask is not None: - x = x * dropout_mask - x = torch.nn.functional.linear(x, weight, bias) - return x - - @staticmethod - @custom_bwd - def backward(ctx, ans_grad: Tensor): - saved = ctx.saved_tensors - (x, weight, bias, dropout_mask) = saved - - forward_and_deriv_activation_dict = { - "SwooshL": k2.swoosh_l_forward_and_deriv, - "SwooshR": k2.swoosh_r_forward_and_deriv, - } - # the following lines a KeyError if the activation is unrecognized. - # This will be an error. We let it propagate to the user. - func = forward_and_deriv_activation_dict[ctx.activation] - - y, func_deriv = func(x) - if dropout_mask is not None: - y = y * dropout_mask - # now compute derivative of y w.r.t. weight and bias.. - # y: (..., in_channels), ans_grad: (..., out_channels), - (out_channels, in_channels) = weight.shape - - in_channels = y.shape[-1] - g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) - y_deriv = torch.matmul(ans_grad, weight) - bias_deriv = None if bias is None else g.sum(dim=0) - x_deriv = y_deriv * func_deriv - if dropout_mask is not None: - # order versus func_deriv does not matter - x_deriv = x_deriv * dropout_mask - - return x_deriv, weight_deriv, bias_deriv, None, None, None - - -class ActivationDropoutAndLinear(torch.nn.Module): - """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. - - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = "SwooshL", - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0, - ): - super().__init__() - # create a temporary module of nn.Linear that we'll steal the - # weights and bias from - l = ScaledLinear( - in_channels, out_channels, bias=bias, initial_scale=initial_scale - ) - - self.weight = l.weight - # register_parameter properly handles making it a parameter when l.bias - # is None. I think there is some reason for doing it this way rather - # than just setting it to None but I don't know what it is, maybe - # something to do with exporting the module.. 
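        # (register_parameter accepts None, so when the layer is built with
        # bias=False, self.bias simply becomes None and the linear calls in
        # forward() handle that case directly.)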
- self.register_parameter("bias", l.bias) - - self.activation = activation - self.dropout_p = dropout_p - self.dropout_shared_dim = dropout_shared_dim - - def forward(self, x: Tensor): - if ( - torch.jit.is_scripting() - or torch.jit.is_tracing() - or "k2" not in sys.modules - ): - if self.activation == "SwooshL": - x = SwooshLForward(x) - elif self.activation == "SwooshR": - x = SwooshRForward(x) - else: - assert False, self.activation - return torch.nn.functional.linear(x, self.weight, self.bias) - - return ActivationDropoutAndLinearFunction.apply( - x, - self.weight, - self.bias, - self.activation, - float(self.dropout_p), - self.dropout_shared_dim, - ) - - -def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: - if num_channels <= x.shape[-1]: - return x[..., :num_channels] - else: - shape = list(x.shape) - shape[-1] = num_channels - shape[-1] - zeros = torch.zeros(shape, dtype=x.dtype, device=x.device) - return torch.cat((x, zeros), dim=-1) - - -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = Balancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - min_abs=0.0, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_sign: x = ", x) - print("_test_balancer_sign: y grad = ", y_grad) - print("_test_balancer_sign: x grad = ", x.grad) - - -def _test_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = Balancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - min_abs=0.2, - max_abs=0.7, - prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_balancer_magnitude: x = ", x) - print("_test_balancer_magnitude: y grad = ", y_grad) - print("_test_balancer_magnitude: x grad = ", x.grad) - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = (1.2 - (-0.043637)) / 255.0 - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshl_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshL() - - tol = 1.0 / 255.0 - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. 
- x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_swooshr_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = SwooshR() - - tol = 1.0 / 255.0 - torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -def _test_piecewise_linear(): - p = PiecewiseLinear((0, 10.0)) - for x in [-100, 0, 100]: - assert p(x) == 10.0 - p = PiecewiseLinear((0, 10.0), (1, 0.0)) - for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: - print("x, y = ", x, y) - assert p(x) == y, (x, p(x), y) - - q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] - pq = p.max(q) - for x in x_vals: - y1 = max(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p.min(q) - for x in x_vals: - y1 = min(p(x), q(x)) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - pq = p + q - for x in x_vals: - y1 = p(x) + q(x) - y2 = pq(x) - assert abs(y1 - y2) < 0.001 - - -def _test_activation_dropout_and_linear(): - in_channels = 20 - out_channels = 30 - - for bias in [True, False]: - # actually we don't test for dropout_p != 0.0 because forward functions will give - # different answers. This is because we are using the k2 implementation of - # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() - # internally, messing up the random state. - for dropout_p in [0.0]: - for activation in ["SwooshL", "SwooshR"]: - m1 = nn.Sequential( - SwooshL() if activation == "SwooshL" else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear( - in_channels, out_channels, bias=bias, initial_scale=0.5 - ), - ) - m2 = ActivationDropoutAndLinear( - in_channels, - out_channels, - bias=bias, - initial_scale=0.5, - activation=activation, - dropout_p=dropout_p, - ) - with torch.no_grad(): - m2.weight[:] = m1[2].weight - if bias: - m2.bias[:] = m1[2].bias - # make sure forward gives same result. - x1 = torch.randn(10, in_channels) - x1.requires_grad = True - - # TEMP. - assert torch.allclose( - SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 - ) - - x2 = x1.clone().detach() - x2.requires_grad = True - seed = 10 - torch.manual_seed(seed) - y1 = m1(x1) - y_grad = torch.randn_like(y1) - y1.backward(gradient=y_grad) - torch.manual_seed(seed) - y2 = m2(x2) - y2.backward(gradient=y_grad) - - print( - f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" - ) - print("y1 = ", y1) - print("y2 = ", y2) - assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) - if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) - print("x1.grad = ", x1.grad) - print("x2.grad = ", x2.grad) - - def isclose(a, b): - # return true if cosine similarity is > 0.9. - return (a * b).sum() > 0.9 * ( - (a**2).sum() * (b**2).sum() - ).sqrt() - - # the SwooshL() implementation has a noisy gradient due to 1-byte - # storage of it. 
- assert isclose(x1.grad, x2.grad) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.DEBUG) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_piecewise_linear() - _test_softmax() - _test_whiten() - _test_balancer_sign() - _test_balancer_magnitude() - _test_double_swish_deriv() - _test_swooshr_deriv() - _test_swooshl_deriv() - _test_activation_dropout_and_linear() diff --git a/egs/zipvoice/zipvoice/solver.py b/egs/zipvoice/zipvoice/solver.py deleted file mode 100644 index a1e316ec8..000000000 --- a/egs/zipvoice/zipvoice/solver.py +++ /dev/null @@ -1,277 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Union - -import torch - - -class DiffusionModel(torch.nn.Module): - """A wrapper of diffusion models for inference. - Args: - model: The diffusion model. - distill: Whether it is a distillation model. - """ - - def __init__( - self, - model: torch.nn.Module, - distill: bool = False, - func_name: str = "forward_fm_decoder", - ): - super().__init__() - self.model = model - self.distill = distill - self.func_name = func_name - self.model_func = getattr(self.model, func_name) - - def forward( - self, - t: torch.Tensor, - x: torch.Tensor, - text_condition: torch.Tensor, - speech_condition: torch.Tensor, - padding_mask: Optional[torch.Tensor] = None, - guidance_scale: Union[float, torch.Tensor] = 0.0, - **kwargs - ) -> torch.Tensor: - """ - Forward function that Handles the classifier-free guidance. - Args: - t: The current timestep, a tensor of shape (batch, 1, 1) or a tensor of a single float. - x: The initial value, with the shape (batch, seq_len, emb_dim). - text_condition: The text_condition of the diffision model, with the shape (batch, seq_len, emb_dim). - speech_condition: The speech_condition of the diffision model, with the shape (batch, seq_len, emb_dim). - padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len). - guidance_scale: The scale of classifier-free guidance, a float or a tensor of shape (batch, 1, 1). - Retrun: - The prediction with the shape (batch, seq_len, emb_dim). 
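# A minimal sketch of the classifier-free guidance combination that the forward
# function above describes; the tensors are random stand-ins for the conditional
# and unconditional predictions of the diffusion model.
import torch

def apply_cfg(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # guidance_scale == 0.0 reduces to the purely conditional prediction.
    return (1 + guidance_scale) * pred_cond - guidance_scale * pred_uncond

pred_cond = torch.randn(2, 10, 100)    # (batch, seq_len, emb_dim)
pred_uncond = torch.randn(2, 10, 100)
out = apply_cfg(pred_cond, pred_uncond, guidance_scale=1.0)
assert out.shape == (2, 10, 100)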
- """ - if not torch.is_tensor(guidance_scale): - guidance_scale = torch.tensor( - guidance_scale, dtype=t.dtype, device=t.device - ) - if self.distill: - return self.model_func( - t=t, - xt=x, - text_condition=text_condition, - speech_condition=speech_condition, - padding_mask=padding_mask, - guidance_scale=guidance_scale, - **kwargs - ) - - if (guidance_scale == 0.0).all(): - return self.model_func( - t=t, - xt=x, - text_condition=text_condition, - speech_condition=speech_condition, - padding_mask=padding_mask, - **kwargs - ) - else: - if t.dim() != 0: - t = torch.cat([t] * 2, dim=0) - - x = torch.cat([x] * 2, dim=0) - padding_mask = torch.cat([padding_mask] * 2, dim=0) - - text_condition = torch.cat( - [torch.zeros_like(text_condition), text_condition], dim=0 - ) - - if t.dim() == 0: - if t > 0.5: - speech_condition = torch.cat( - [torch.zeros_like(speech_condition), speech_condition], dim=0 - ) - else: - guidance_scale = guidance_scale * 2 - speech_condition = torch.cat( - [speech_condition, speech_condition], dim=0 - ) - else: - assert t.dim() > 0, t - larger_t_index = (t > 0.5).squeeze(1).squeeze(1) - zero_speech_condition = torch.cat( - [torch.zeros_like(speech_condition), speech_condition], dim=0 - ) - speech_condition = torch.cat( - [speech_condition, speech_condition], dim=0 - ) - speech_condition[larger_t_index] = zero_speech_condition[larger_t_index] - guidance_scale[~larger_t_index[: larger_t_index.size(0) // 2]] *= 2 - - data_uncond, data_cond = self.model_func( - t=t, - xt=x, - text_condition=text_condition, - speech_condition=speech_condition, - padding_mask=padding_mask, - **kwargs - ).chunk(2, dim=0) - - res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond - return res - - -class EulerSolver: - def __init__( - self, - model: torch.nn.Module, - distill: bool = False, - func_name: str = "forward_fm_decoder", - ): - """Construct a Euler Solver - Args: - model: The diffusion model. - distill: Whether it is distillation model. - """ - - self.model = DiffusionModel(model, distill=distill, func_name=func_name) - - def sample( - self, - x: torch.Tensor, - text_condition: torch.Tensor, - speech_condition: torch.Tensor, - padding_mask: torch.Tensor, - num_step: int = 10, - guidance_scale: Union[float, torch.Tensor] = 0.0, - t_start: Union[float, torch.Tensor] = 0.0, - t_end: Union[float, torch.Tensor] = 1.0, - t_shift: float = 1.0, - **kwargs - ) -> torch.Tensor: - """ - Compute the sample at time `t_end` by Euler Solver. - Args: - x: The initial value at time `t_start`, with the shape (batch, seq_len, emb_dim). - text_condition: The text condition of the diffision mode, with the shape (batch, seq_len, emb_dim). - speech_condition: The speech condition of the diffision model, with the shape (batch, seq_len, emb_dim). - padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len). - num_step: The number of ODE steps. - guidance_scale: The scale for classifier-free guidance, which is - a float or a tensor with the shape (batch, 1, 1). - t_start: the start timestep in the range of [0, 1], - which is a float or a tensor with the shape (batch, 1, 1). - t_end: the end time_step in the range of [0, 1], - which is a float or a tensor with the shape (batch, 1, 1). - t_shift: shift the t toward smaller numbers so that the sampling - will emphasize low SNR region. Should be in the range of (0, 1]. - The shifting will be more significant when the number is smaller. - - Returns: - The approximated solution at time `t_end`. 
- """ - device = x.device - - if torch.is_tensor(t_start) and t_start.dim() > 0: - timesteps = get_time_steps_batch( - t_start=t_start, - t_end=t_end, - num_step=num_step, - t_shift=t_shift, - device=device, - ) - else: - timesteps = get_time_steps( - t_start=t_start, - t_end=t_end, - num_step=num_step, - t_shift=t_shift, - device=device, - ) - for step in range(num_step): - v = self.model( - t=timesteps[step], - x=x, - text_condition=text_condition, - speech_condition=speech_condition, - padding_mask=padding_mask, - guidance_scale=guidance_scale, - **kwargs - ) - x = x + v * (timesteps[step + 1] - timesteps[step]) - return x - - -def get_time_steps( - t_start: float = 0.0, - t_end: float = 1.0, - num_step: int = 10, - t_shift: float = 1.0, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """Compute the intermediate time steps for sampling. - - Args: - t_start: The starting time of the sampling (default is 0). - t_end: The starting time of the sampling (default is 1). - num_step: The number of sampling. - t_shift: shift the t toward smaller numbers so that the sampling - will emphasize low SNR region. Should be in the range of (0, 1]. - The shifting will be more significant when the number is smaller. - device: A torch device. - Returns: - The time step with the shape (num_step + 1,). - """ - - timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device) - - timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps) - - return timesteps - - -def get_time_steps_batch( - t_start: torch.Tensor, - t_end: torch.Tensor, - num_step: int = 10, - t_shift: float = 1.0, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """Compute the intermediate time steps for sampling in the batch mode. - - Args: - t_start: The starting time of the sampling (default is 0), with the shape (batch, 1, 1). - t_end: The starting time of the sampling (default is 1), with the shape (batch, 1, 1). - num_step: The number of sampling. - t_shift: shift the t toward smaller numbers so that the sampling - will emphasize low SNR region. Should be in the range of (0, 1]. - The shifting will be more significant when the number is smaller. - device: A torch device. - Returns: - The time step with the shape (num_step + 1, N, 1, 1). - """ - while t_start.dim() > 1 and t_start.size(-1) == 1: - t_start = t_start.squeeze(-1) - while t_end.dim() > 1 and t_end.size(-1) == 1: - t_end = t_end.squeeze(-1) - assert t_start.dim() == t_end.dim() == 1 - - timesteps_shape = (num_step + 1, t_start.size(0)) - timesteps = torch.zeros(timesteps_shape, device=device) - - for i in range(t_start.size(0)): - timesteps[:, i] = torch.linspace(t_start[i], t_end[i], steps=num_step + 1) - - timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps) - - return timesteps.unsqueeze(-1).unsqueeze(-1) diff --git a/egs/zipvoice/zipvoice/tokenizer.py b/egs/zipvoice/zipvoice/tokenizer.py deleted file mode 100644 index ea25d3498..000000000 --- a/egs/zipvoice/zipvoice/tokenizer.py +++ /dev/null @@ -1,572 +0,0 @@ -# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao -# Han Zhu, -# Wei Kang) -# -# See ../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re -import unicodedata -from functools import reduce -from typing import Dict, List, Optional - -import cn2an -import inflect -import jieba -from pypinyin import Style, lazy_pinyin -from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials - -try: - from piper_phonemize import phonemize_espeak -except Exception as ex: - raise RuntimeError( - f"{ex}\nPlease run\n" - "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" - ) - -_inflect = inflect.engine() -_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") -_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") -_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)") -_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") -_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") -_fraction_re = re.compile(r"([0-9]+)/([0-9]+)") -_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") -_number_re = re.compile(r"[0-9]+") - -# List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [ - (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1]) - for x in [ - ("mrs", "misess"), - ("mr", "mister"), - ("dr", "doctor"), - ("st", "saint"), - ("co", "company"), - ("jr", "junior"), - ("maj", "major"), - ("gen", "general"), - ("drs", "doctors"), - ("rev", "reverend"), - ("lt", "lieutenant"), - ("hon", "honorable"), - ("sgt", "sergeant"), - ("capt", "captain"), - ("esq", "esquire"), - ("ltd", "limited"), - ("col", "colonel"), - ("ft", "fort"), - ("etc", "et cetera"), - ("btw", "by the way"), - ] -] - - -def intersperse(sequence, item=0): - result = [item] * (len(sequence) * 2 + 1) - result[1::2] = sequence - return result - - -def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text - - -def _remove_commas(m): - return m.group(1).replace(",", "") - - -def _expand_decimal_point(m): - return m.group(1).replace(".", " point ") - - -def _expand_percent(m): - return m.group(1).replace("%", " percent ") - - -def _expand_dollars(m): - match = m.group(1) - parts = match.split(".") - if len(parts) > 2: - return " " + match + " dollars " # Unexpected format - dollars = int(parts[0]) if parts[0] else 0 - cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 - if dollars and cents: - dollar_unit = "dollar" if dollars == 1 else "dollars" - cent_unit = "cent" if cents == 1 else "cents" - return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit) - elif dollars: - dollar_unit = "dollar" if dollars == 1 else "dollars" - return " %s %s " % (dollars, dollar_unit) - elif cents: - cent_unit = "cent" if cents == 1 else "cents" - return " %s %s " % (cents, cent_unit) - else: - return " zero dollars " - - -def fraction_to_words(numerator, denominator): - if numerator == 1 and denominator == 2: - return " one half " - if numerator == 1 and denominator == 4: - return " one quarter " - if denominator == 2: - return " " + _inflect.number_to_words(numerator) + " halves " - if denominator == 4: - return " " + _inflect.number_to_words(numerator) + " quarters " - return ( - " " - + _inflect.number_to_words(numerator) 
- + " " - + _inflect.ordinal(_inflect.number_to_words(denominator)) - + " " - ) - - -def _expand_fraction(m): - numerator = int(m.group(1)) - denominator = int(m.group(2)) - return fraction_to_words(numerator, denominator) - - -def _expand_ordinal(m): - return " " + _inflect.number_to_words(m.group(0)) + " " - - -def _expand_number(m): - num = int(m.group(0)) - if num > 1000 and num < 3000: - if num == 2000: - return " two thousand " - elif num > 2000 and num < 2010: - return " two thousand " + _inflect.number_to_words(num % 100) + " " - elif num % 100 == 0: - return " " + _inflect.number_to_words(num // 100) + " hundred " - else: - return ( - " " - + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace( - ", ", " " - ) - + " " - ) - else: - return " " + _inflect.number_to_words(num, andword="") + " " - - -# Normalize numbers pronunciation -def normalize_numbers(text): - text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r"\1 pounds", text) - text = re.sub(_dollars_re, _expand_dollars, text) - text = re.sub(_fraction_re, _expand_fraction, text) - text = re.sub(_decimal_number_re, _expand_decimal_point, text) - text = re.sub(_percent_number_re, _expand_percent, text) - text = re.sub(_ordinal_re, _expand_ordinal, text) - text = re.sub(_number_re, _expand_number, text) - return text - - -# Convert numbers to Chinese pronunciation -def number_to_chinese(text): - text = cn2an.transform(text, "an2cn") - return text - - -def map_punctuations(text): - text = text.replace(",", ",") - text = text.replace("。", ".") - text = text.replace("!", "!") - text = text.replace("?", "?") - text = text.replace(";", ";") - text = text.replace(":", ":") - text = text.replace("、", ",") - text = text.replace("‘", "'") - text = text.replace("“", '"') - text = text.replace("”", '"') - text = text.replace("’", "'") - text = text.replace("⋯", "…") - text = text.replace("···", "…") - text = text.replace("・・・", "…") - text = text.replace("...", "…") - return text - - -def is_chinese(char): - if char >= "\u4e00" and char <= "\u9fa5": - return True - else: - return False - - -def is_alphabet(char): - if (char >= "\u0041" and char <= "\u005a") or ( - char >= "\u0061" and char <= "\u007a" - ): - return True - else: - return False - - -def is_hangul(char): - letters = unicodedata.normalize("NFD", char) - return all( - ["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters] - ) - - -def is_japanese(char): - return any( - [ - start <= char <= end - for start, end in [ - ("\u3041", "\u3096"), - ("\u30a0", "\u30ff"), - ("\uff5f", "\uff9f"), - ("\u31f0", "\u31ff"), - ("\u3220", "\u3243"), - ("\u3280", "\u337f"), - ] - ] - ) - - -def get_segment(text: str) -> List[str]: - # sentence --> [ch_part, en_part, ch_part, ...] - # example : - # input : 我们是小米人,是吗? Yes I think so!霍...啦啦啦 - # output : [('我们是小米人,是吗? 
', 'zh'), ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')] - segments = [] - types = [] - flag = 0 - temp_seg = "" - temp_lang = "" - - for i, ch in enumerate(text): - if is_chinese(ch): - types.append("zh") - elif is_alphabet(ch): - types.append("en") - else: - types.append("other") - - assert len(types) == len(text) - - for i in range(len(types)): - # find the first char of the seg - if flag == 0: - temp_seg += text[i] - temp_lang = types[i] - flag = 1 - else: - if temp_lang == "other": - if types[i] == temp_lang: - temp_seg += text[i] - else: - temp_seg += text[i] - temp_lang = types[i] - else: - if types[i] == temp_lang: - temp_seg += text[i] - elif types[i] == "other": - temp_seg += text[i] - else: - segments.append((temp_seg, temp_lang)) - temp_seg = text[i] - temp_lang = types[i] - flag = 1 - - segments.append((temp_seg, temp_lang)) - return segments - - -def preprocess(text: str) -> str: - text = map_punctuations(text) - return text - - -def tokenize_ZH(text: str) -> List[str]: - try: - text = number_to_chinese(text) - segs = list(jieba.cut(text)) - full = lazy_pinyin( - segs, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True - ) - phones = [] - for x in full: - # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5]. - if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")): - phones.append(x) - continue - initial = to_initials(x, strict=False) - # don't want to share tokens with espeak tokens, so use tone3 style - final = to_finals_tone3(x, strict=False, neutral_tone_with_five=True) - if initial != "": - # don't want to share tokens with espeak tokens, so add a '0' after each initial - phones.append(initial + "0") - if final != "": - phones.append(final) - return phones - except Exception as ex: - logging.warning(f"Tokenize ZH failed: {ex}") - return [] - - -def tokenize_EN(text: str) -> List[str]: - try: - text = expand_abbreviations(text) - text = normalize_numbers(text) - tokens = phonemize_espeak(text, "en-us") - tokens = reduce(lambda x, y: x + y, tokens) - return tokens - except Exception as ex: - logging.warning(f"Tokenize EN failed: {ex}") - return [] - - -class TokenizerEmilia(object): - def __init__(self, token_file: Optional[str] = None, token_type="phone"): - """ - Args: - tokens: the file that contains information that maps tokens to ids, - which is a text file with '{token} {token_id}' per line. - """ - assert ( - token_type == "phone" - ), f"Only support phone tokenizer for Emilia, but get {token_type}." - self.has_tokens = False - if token_file is None: - logging.debug( - "Initialize Tokenizer without tokens file, will fail when map to ids." - ) - return - self.token2id: Dict[str, int] = {} - with open(token_file, "r", encoding="utf-8") as f: - for line in f.readlines(): - info = line.rstrip().split("\t") - token, id = info[0], int(info[1]) - assert token not in self.token2id, token - self.token2id[token] = id - self.pad_id = self.token2id["_"] # padding - - self.vocab_size = len(self.token2id) - self.has_tokens = True - - def texts_to_token_ids( - self, - texts: List[str], - ) -> List[List[int]]: - return self.tokens_to_token_ids(self.texts_to_tokens(texts)) - - def texts_to_tokens( - self, - texts: List[str], - ) -> List[List[str]]: - """ - Args: - texts: - A list of transcripts. 
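# A small sketch of the pinyin-based phone scheme used by tokenize_ZH above: each
# syllable is split into an initial (suffixed with "0" so it cannot clash with
# espeak tokens) and a tone3-style final. Requires pypinyin; the printed result is
# what recent pypinyin versions typically return for this input.
from pypinyin import Style, lazy_pinyin
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials

syllables = lazy_pinyin(
    ["中国"], style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True
)
phones = []
for s in syllables:
    initial = to_initials(s, strict=False)
    final = to_finals_tone3(s, strict=False, neutral_tone_with_five=True)
    if initial:
        phones.append(initial + "0")
    if final:
        phones.append(final)
print(phones)  # typically ['zh0', 'ong1', 'g0', 'uo2']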
- Returns: - Return a list of a list of tokens [utterance][token] - """ - for i in range(len(texts)): - # Text normalization - texts[i] = preprocess(texts[i]) - - phoneme_list = [] - for text in texts: - # now only en and ch - segments = get_segment(text) - all_phoneme = [] - for index in range(len(segments)): - seg = segments[index] - if seg[1] == "zh": - phoneme = tokenize_ZH(seg[0]) - else: - if seg[1] != "en": - logging.warning( - f"The lang should be en, given {seg[1]}, skipping segment : {seg}" - ) - continue - phoneme = tokenize_EN(seg[0]) - all_phoneme += phoneme - phoneme_list.append(all_phoneme) - return phoneme_list - - def tokens_to_token_ids( - self, - tokens: List[List[str]], - intersperse_blank: bool = False, - ) -> List[List[int]]: - """ - Args: - tokens_list: - A list of token list, each corresponding to one utterance. - intersperse_blank: - Whether to intersperse blanks in the token sequence. - - Returns: - Return a list of token id list [utterance][token_id] - """ - assert self.has_tokens, "Please initialize Tokenizer with a tokens file." - token_ids = [] - - for tks in tokens: - ids = [] - for t in tks: - if t not in self.token2id: - logging.warning(f"Skip OOV {t}") - continue - ids.append(self.token2id[t]) - - if intersperse_blank: - ids = intersperse(ids, self.pad_id) - - token_ids.append(ids) - - return token_ids - - -class TokenizerLibriTTS(object): - def __init__(self, token_file: str, token_type: str): - """ - Args: - type: the type of tokenizer, e.g., bpe, char, phone. - tokens: the file that contains information that maps tokens to ids, - which is a text file with '{token} {token_id}' per line if type is - char or phone, otherwise it is a bpe_model file. - """ - self.type = token_type - assert token_type in ["bpe", "char", "phone"] - # Parse token file - - if token_type == "bpe": - import sentencepiece as spm - - self.sp = spm.SentencePieceProcessor() - self.sp.load(token_file) - self.pad_id = self.sp.piece_to_id("") - self.vocab_size = self.sp.get_piece_size() - else: - self.token2id: Dict[str, int] = {} - with open(token_file, "r", encoding="utf-8") as f: - for line in f.readlines(): - info = line.rstrip().split("\t") - token, id = info[0], int(info[1]) - assert token not in self.token2id, token - self.token2id[token] = id - self.pad_id = self.token2id["_"] # padding - self.vocab_size = len(self.token2id) - try: - from tacotron_cleaner.cleaners import custom_english_cleaners as cleaner - except Exception as ex: - raise RuntimeError( - f"{ex}\nPlease run\n" - "pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/" - ) - self.cleaner = cleaner - - def texts_to_token_ids( - self, - texts: List[str], - lang: str = "en-us", - ) -> List[List[int]]: - """ - Args: - texts: - A list of transcripts. - intersperse_blank: - Whether to intersperse blanks in the token sequence. - Used when alignment is from MAS. - lang: - Language argument passed to phonemize_espeak(). 
- - Returns: - Return a list of token id list [utterance][token_id] - """ - for i in range(len(texts)): - # Text normalization - texts[i] = self.cleaner(texts[i]) - - if self.type == "bpe": - token_ids_list = self.sp.encode(texts) - - elif self.type == "phone": - token_ids_list = [] - for text in texts: - tokens_list = phonemize_espeak(text.lower(), lang) - tokens = [] - for t in tokens_list: - tokens.extend(t) - token_ids = [] - for t in tokens: - if t not in self.token2id: - logging.warning(f"Skip OOV {t}") - continue - token_ids.append(self.token2id[t]) - - token_ids_list.append(token_ids) - else: - token_ids_list = [] - for text in texts: - token_ids = [] - for t in text: - if t not in self.token2id: - logging.warning(f"Skip OOV {t}") - continue - token_ids.append(self.token2id[t]) - - token_ids_list.append(token_ids) - - return token_ids_list - - def tokens_to_token_ids( - self, - tokens_list: List[str], - ) -> List[List[int]]: - """ - Args: - tokens_list: - A list of token list, each corresponding to one utterance. - - Returns: - Return a list of token id list [utterance][token_id] - """ - token_ids_list = [] - - for tokens in tokens_list: - token_ids = [] - for t in tokens: - if t not in self.token2id: - logging.warning(f"Skip OOV {t}") - continue - token_ids.append(self.token2id[t]) - - token_ids_list.append(token_ids) - - return token_ids_list - - -if __name__ == "__main__": - text = "我们是5年小米人,是吗? Yes I think so! mr king, 5 years, from 2019 to 2024. 霍...啦啦啦超过90%的人咯...?!9204" - tokenizer = TokenizerEmilia() - tokens = tokenizer.texts_to_tokens([text]) - print(f"tokens : {tokens}") - tokens2 = "|".join(tokens[0]) - print(f"tokens2 : {tokens2}") - tokens2 = tokens2.split("|") - assert tokens[0] == tokens2 diff --git a/egs/zipvoice/zipvoice/train_distill.py b/egs/zipvoice/zipvoice/train_distill.py deleted file mode 100644 index 9e52a3790..000000000 --- a/egs/zipvoice/zipvoice/train_distill.py +++ /dev/null @@ -1,1043 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script trains a ZipVoice-Distill model starting from a ZipVoice model. -It has two distillation stages. - -Usage: - -(1) The first distillation stage with a fixed ZipVoice model as the teacher. -python3 zipvoice/train_distill.py \ - --world-size 8 \ - --use-fp16 1 \ - --tensorboard 1 \ - --dataset "emilia" \ - --base-lr 0.0005 \ - --max-duration 500 \ - --token-file "data/tokens_emilia.txt" \ - --manifest-dir "data/fbank" \ - --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \ - --num-updates 60000 \ - --distill-stage "first" \ - --exp-dir zipvoice/exp_zipvoice_distill_1stage - -(2) The second distillation stage with a EMA model as the teacher. 
-python3 zipvoice/train_distill.py \ - --world-size 8 \ - --use-fp16 1 \ - --tensorboard 1 \ - --dataset "emilia" \ - --base-lr 0.0001 \ - --max-duration 500 \ - --token-file "data/tokens_emilia.txt" \ - --manifest-dir "data/fbank" \ - --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \ - --num-updates 2000 \ - --distill-stage "second" \ - --exp-dir zipvoice/exp_zipvoice_distill -""" - -import argparse -import copy -import logging -import os -import random -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from checkpoint import load_checkpoint, save_checkpoint -from lhotse.cut import Cut, CutSet -from lhotse.utils import fix_random_seed -from model import get_distill_model, get_model -from optim import FixedLRScheduler, ScaledAdam -from tokenizer import TokenizerEmilia, TokenizerLibriTTS -from torch import Tensor -from torch.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from train_flow import add_model_arguments, get_params -from tts_datamodule import TtsDataModule -from utils import ( - condition_time_mask, - get_adjusted_batch_count, - prepare_input, - set_batch_count, -) - -from icefall import diagnostics -from icefall.checkpoint import ( - remove_checkpoints, - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - make_pad_mask, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=1, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--num-updates", - type=int, - default=0, - help="Number of updates to train, will ignore num_epochs if > 0.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--teacher-model", - type=str, - help="""Checkpoints of pre-trained teacher model""", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipvoice/exp_zipvoice_distill", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.001, help="The base learning rate." 
- ) - - parser.add_argument( - "--ref-duration", - type=float, - default=50, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--feat-scale", - type=float, - default=0.1, - help="The scale factor of fbank feature", - ) - - parser.add_argument( - "--ema-decay", - type=float, - default=0.9999, - help="The EMA decay factor of target model in distillation.", - ) - parser.add_argument( - "--distill-stage", - type=str, - choices=["first", "second"], - help="The stage of distillation.", - ) - - parser.add_argument( - "--dataset", - type=str, - default="emilia", - choices=["emilia", "libritts"], - help="The used training dataset", - ) - - add_model_arguments(parser) - - return parser - - -def ema(new_model, ema_model, decay): - if isinstance(new_model, DDP): - new_model = new_model.module - if isinstance(ema_model, DDP): - ema_model = ema_model.module - new_model_dict = new_model.state_dict() - ema_model_dict = ema_model.state_dict() - for key in new_model_dict.keys(): - ema_model_dict[key].data.copy_( - ema_model_dict[key].data * decay + new_model_dict[key].data * (1 - decay) - ) - - -def resume_checkpoint( - params: AttributeDict, model: nn.Module, model_avg: nn.Module, model_ema: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. 
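# A per-parameter sketch that is equivalent to the state_dict-based ema() update
# above: each teacher parameter is moved toward the corresponding student parameter
# with a decay close to 1. The two Linear modules are toy stand-ins.
import torch

student = torch.nn.Linear(4, 4)
teacher = torch.nn.Linear(4, 4)
decay = 0.9999
with torch.no_grad():
    for p_teacher, p_student in zip(teacher.parameters(), student.parameters()):
        p_teacher.mul_(decay).add_(p_student, alpha=1 - decay)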
- - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. - """ - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, model=model, model_avg=model_avg, model_ema=model_ema, strict=True - ) - - if params.start_epoch > 1: - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def compute_fbank_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - teacher_model: Union[nn.Module, DDP], - features: Tensor, - features_lens: Tensor, - tokens: List[List[int]], - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. - teacher_model: - The teacher model for distillation. - features: - The target acoustic feature. - features_lens: - The number of frames of each utterance. - tokens: - Input tokens that representing the transcripts. - durations: - Duration of each token. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - batch_size, num_frames, _ = features.shape - - features = torch.nn.functional.pad( - features, (0, 0, 0, num_frames - features.size(1)) - ) # (B, T, F) - noise = torch.randn_like(features) # (B, T, F) - - # Sampling t and guidance_scale from uniform distribution - - t_value = random.random() - t = torch.ones(batch_size, 1, 1, device=device) * t_value - if params.distill_stage == "first": - guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 - else: - guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 + 1 - xt = features * t + noise * (1 - t) - t_delta_fix = random.uniform(0.0, min(0.3, 1 - t_value)) - t_delta_ema = random.uniform(0.0, min(0.3, 1 - t_value - t_delta_fix)) - t_dest = t + t_delta_fix + t_delta_ema - - with torch.no_grad(): - speech_condition_mask = condition_time_mask( - features_lens=features_lens, - mask_percent=(0.7, 1.0), - max_len=features.size(1), - ) - - if params.distill_stage == "first": - teacher_x_t_mid, _ = teacher_model.sample_intermediate( - tokens=tokens, - features=features, - features_lens=features_lens, - noise=xt, - speech_condition_mask=speech_condition_mask, - t_start=t, - t_end=t + t_delta_fix, - num_step=1, - guidance_scale=guidance_scale, - ) - - target_x1, _ = teacher_model.sample_intermediate( - tokens=tokens, - features=features, - features_lens=features_lens, - noise=teacher_x_t_mid, - speech_condition_mask=speech_condition_mask, - t_start=t + t_delta_fix, - t_end=t_dest, - num_step=1, - guidance_scale=guidance_scale, - ) - else: - teacher_x_t_mid, _ = teacher_model( - tokens=tokens, - features=features, - features_lens=features_lens, - noise=xt, - speech_condition_mask=speech_condition_mask, - t_start=t, - t_end=t + t_delta_fix, - num_step=1, - guidance_scale=guidance_scale, - ) - - target_x1, _ = teacher_model( - tokens=tokens, - features=features, - features_lens=features_lens, - noise=teacher_x_t_mid, - speech_condition_mask=speech_condition_mask, - t_start=t + 
t_delta_fix, - t_end=t_dest, - num_step=1, - guidance_scale=guidance_scale, - ) - - with torch.set_grad_enabled(is_training): - - pred_x1, _ = model( - tokens=tokens, - features=features, - features_lens=features_lens, - noise=xt, - speech_condition_mask=speech_condition_mask, - t_start=t, - t_end=t_dest, - num_step=1, - guidance_scale=guidance_scale, - ) - pred_v = (pred_x1 - xt) / (t_dest - t) - - padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T) - loss_mask = speech_condition_mask & (~padding_mask) - - target_v = (target_x1 - xt) / (t_dest - t) - loss = torch.mean((pred_v[loss_mask] - target_v[loss_mask]) ** 2) - - ut = features - noise # (B, T, F) - - ref_loss = torch.mean((pred_v[loss_mask] - ut[loss_mask]) ** 2) - - assert loss.requires_grad == is_training - info = MetricsTracker() - num_frames = features_lens.sum().item() - info["frames"] = num_frames - info["loss"] = loss.detach().cpu().item() * num_frames - info["ref_loss"] = ref_loss.detach().cpu().item() * num_frames - return loss, info - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - teacher_model: Union[nn.Module, DDP], - tokenizer: TokenizerEmilia, - optimizer: Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - teacher_model: - The model for distillation. - tokenizer: - Used to convert text to tokens. - optimizer: - The optimizer. - scheduler: - The learning rate scheduler, we call step() every epoch. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
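# A toy-scale sketch of the distillation objective computed in compute_fbank_loss
# above: the student's one-step prediction over [t, t_dest] is turned into an
# average velocity and regressed against the velocity implied by the teacher's
# prediction. All tensors below are random stand-ins for real features and model outputs.
import torch

batch, seq_len, feat_dim = 2, 50, 100
features = torch.randn(batch, seq_len, feat_dim)
noise = torch.randn_like(features)

t = torch.full((batch, 1, 1), 0.4)
t_dest = t + 0.2
xt = features * t + noise * (1 - t)          # point on the flow at time t

target_x1 = torch.randn_like(features)       # stand-in for the teacher's prediction at t_dest
pred_x1 = torch.randn_like(features)         # stand-in for the student's one-step prediction

pred_v = (pred_x1 - xt) / (t_dest - t)       # average velocity implied by the student
target_v = (target_x1 - xt) / (t_dest - t)   # average velocity implied by the teacher
loss = torch.mean((pred_v - target_v) ** 2)  # the real loss is restricted to a mask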
- """ - model.train() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - model_ema=teacher_model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params) + 100000) - - if ( - params.valid_interval is None - and batch_idx == 0 - and not params.print_diagnostics - ) or ( - params.valid_interval is not None - and params.batch_idx_train % params.valid_interval == 0 - and not params.print_diagnostics - ): - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - teacher_model=teacher_model, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - params.batch_idx_train += 1 - - batch_size = len(batch["text"]) - - tokens, features, features_lens = prepare_input( - params=params, - batch=batch, - device=device, - tokenizer=tokenizer, - return_tokens=True, - return_feature=True, - ) - - try: - with autocast("cuda", enabled=params.use_fp16): - loss, loss_info = compute_fbank_loss( - params=params, - model=model, - teacher_model=teacher_model, - features=features, - features_lens=features_lens, - tokens=tokens, - is_training=True, - ) - - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - scaler.scale(loss).backward() - - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - if params.distill_stage == "second": - ema(model, teacher_model, params.ema_decay) - except RuntimeError as e: - if "out of memory" in str(e): - logging.info(f"out of memory error at rank {rank}") - # optimizer.zero_grad() - # duration_optimizer.zero_grad() - torch.cuda.empty_cache() - raise - continue - else: - logging.info(f"Caught exception : {e}.") - save_bad_model() - raise - except Exception as e: - logging.info(f"Caught exception : {e}.") - save_bad_model() - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - if ( - params.batch_idx_train > 0 - and params.num_updates > 0 - and params.batch_idx_train > params.num_updates - ): - break - if params.batch_idx_train % 100 == 0 and 
params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 1024.0 or ( - cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - f"cur_lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - loss_value = tot_loss["loss"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - teacher_model: Optional[nn.Module], - tokenizer: TokenizerEmilia, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to summary the stats over iterations - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - tokens, features, features_lens = prepare_input( - params=params, - batch=batch, - device=device, - tokenizer=tokenizer, - return_tokens=True, - return_feature=True, - ) - - loss, loss_info = compute_fbank_loss( - params=params, - model=model, - teacher_model=teacher_model, - features=features, - features_lens=features_lens, - tokens=tokens, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - os.makedirs(f"{params.exp_dir}/fbank", exist_ok=True) - - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - if params.dataset == "emilia": - tokenizer = TokenizerEmilia( - token_file=params.token_file, token_type=params.token_type - ) - elif params.dataset == "libritts": - tokenizer = TokenizerLibriTTS( - token_file=params.token_file, token_type=params.token_type - ) - - params.vocab_size = tokenizer.vocab_size - params.pad_id = tokenizer.pad_id - - params.device = device - - logging.info(params) - - logging.info("About to create model") - - assert params.teacher_model is not None - logging.info(f"Loading pre-trained model from {params.teacher_model}") - model = get_distill_model(params) - _ = load_checkpoint( - filename=params.teacher_model, - model=model, - strict=(params.distill_stage == "second"), - ) - - if params.distill_stage == "first": - teacher_model = get_model(params) - _ = load_checkpoint( - filename=params.teacher_model, model=teacher_model, strict=True - ) - else: - teacher_model = copy.deepcopy(model) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of parameters : {num_param}") - - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - assert params.start_epoch > 0, params.start_epoch - if params.start_epoch > 1: - logging.info(f"Resuming from epoch {params.start_epoch}") - if params.distill_stage == "first": - checkpoints = resume_checkpoint( - params=params, model=model, model_avg=model_avg - ) - else: - checkpoints = resume_checkpoint( - params=params, model=model, model_avg=model_avg, model_ema=teacher_model - ) - - model = model.to(device) - teacher_model.to(device) - teacher_model.eval() - - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - # only update the fm_decoder - num_trainable = 0 - for name, p in model.named_parameters(): - if "fm_decoder" in name: - p.requires_grad = True - num_trainable += p.numel() - else: - p.requires_grad = False - - logging.info( - "A total of {} trainable parameters ({:.3f}% of the whole model)".format( - num_trainable, num_trainable / num_param * 100 - ) - ) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs( - model, - lr=params.base_lr, - include_names=True, - ), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - scheduler = FixedLRScheduler(optimizer) - - scaler = GradScaler("cuda", enabled=params.use_fp16) - - if params.start_epoch > 1 and checkpoints is not None: - # load state_dict for optimizers - if "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - # load state_dict for schedulers - if "scheduler" in checkpoints: - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state 
dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt_emilia(c: Cut): - if c.duration < 1.0 or c.duration > 30.0: - return False - return True - - def remove_short_and_long_utt_libritts(c: Cut): - if c.duration < 1.0 or c.duration > 20.0: - return False - return True - - datamodule = TtsDataModule(args) - if params.dataset == "emilia": - train_cuts = CutSet.mux( - datamodule.train_emilia_EN_cuts(), - datamodule.train_emilia_ZH_cuts(), - weights=[46000, 49000], - ) - train_cuts = train_cuts.filter(remove_short_and_long_utt_emilia) - dev_cuts = CutSet.mux( - datamodule.dev_emilia_EN_cuts(), - datamodule.dev_emilia_ZH_cuts(), - weights=[0.5, 0.5], - ) - elif params.dataset == "libritts": - train_cuts = datamodule.train_libritts_cuts() - train_cuts = train_cuts.filter(remove_short_and_long_utt_libritts) - dev_cuts = datamodule.dev_libritts_cuts() - - train_dl = datamodule.train_dataloaders(train_cuts) - - valid_dl = datamodule.dev_dataloaders(dev_cuts) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - teacher_model=teacher_model, - tokenizer=tokenizer, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - model_avg=model_avg, - model_ema=teacher_model, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - TtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/zipvoice/zipvoice/train_flow.py b/egs/zipvoice/zipvoice/train_flow.py deleted file mode 100644 index 0bf023273..000000000 --- a/egs/zipvoice/zipvoice/train_flow.py +++ /dev/null @@ -1,1108 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. 
(authors: Wei Kang, -# Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script trains a ZipVoice model with the flow-matching loss. - -Usage: - -python3 zipvoice/train_flow.py \ - --world-size 8 \ - --use-fp16 1 \ - --dataset emilia \ - --max-duration 500 \ - --lr-hours 30000 \ - --lr-batches 7500 \ - --token-file "data/tokens_emilia.txt" \ - --manifest-dir "data/fbank" \ - --num-epochs 11 \ - --exp-dir zipvoice/exp_zipvoice -""" - -import argparse -import copy -import logging -import os -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from checkpoint import load_checkpoint, save_checkpoint -from lhotse.cut import Cut, CutSet -from lhotse.utils import fix_random_seed -from model import get_model -from optim import Eden, ScaledAdam -from tokenizer import TokenizerEmilia, TokenizerLibriTTS -from torch import Tensor -from torch.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import TtsDataModule -from utils import get_adjusted_batch_count, prepare_input, set_batch_count - -from icefall import diagnostics -from icefall.checkpoint import ( - remove_checkpoints, - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - get_parameter_groups_with_lrs, - setup_logger, - str2bool, -) - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--fm-decoder-downsampling-factor", - type=str, - default="1,2,4,2,1", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--fm-decoder-num-layers", - type=str, - default="2,2,4,4,4", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--fm-decoder-cnn-module-kernel", - type=str, - default="31,15,7,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--fm-decoder-feedforward-dim", - type=int, - default=1536, - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--fm-decoder-num-heads", - type=int, - default=4, - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--fm-decoder-dim", - type=int, - default=512, - help="Embedding dimension in encoder stacks: a single int or 
comma-separated list.", - ) - - parser.add_argument( - "--text-encoder-downsampling-factor", - type=str, - default="1", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--text-encoder-num-layers", - type=str, - default="4", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--text-encoder-feedforward-dim", - type=int, - default=512, - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--text-encoder-cnn-module-kernel", - type=str, - default="9", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--text-encoder-num-heads", - type=int, - default=4, - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--text-encoder-dim", - type=int, - default=192, - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=int, - default=32, - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=int, - default=12, - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=int, - default=4, - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default=48, - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--time-embed-dim", - type=int, - default=192, - help="Embedding dimension of timestamps embedding.", - ) - - parser.add_argument( - "--text-embed-dim", - type=int, - default=192, - help="Embedding dimension of text embedding.", - ) - - parser.add_argument( - "--token-type", - type=str, - default="phone", - choices=["phone", "char", "bpe"], - help="Input token type of TTS model, by default, " - "we use phone for emilia, char for libritts.", - ) - - parser.add_argument( - "--token-file", - type=str, - default="data/tokens_emilia.txt", - help="The file that contains information that maps tokens to ids," - "which is a text file with '{token}\t{token_id}' per line if type is" - "char or phone, otherwise it is a bpe_model file.", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=11, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. 
- If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--checkpoint", - type=str, - default=None, - help="""Checkpoints of pre-trained models, will load it if not None - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipvoice/exp_zipvoice", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.02, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=10, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--lr-hours", - type=float, - default=0, - help="""If positive, --epoch is ignored and it specifies the number of hours - that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=50, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
- """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--feat-scale", - type=float, - default=0.1, - help="The scale factor of fbank feature", - ) - - parser.add_argument( - "--condition-drop-ratio", - type=float, - default=0.2, - help="The drop rate of text condition during training.", - ) - - parser.add_argument( - "--dataset", - type=str, - default="emilia", - choices=["emilia", "libritts"], - help="The used training dataset", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - sampling_rate: Sampling rate of the wavform. - - - frame_shift_ms: Frame shift in milliseconds. - - - feat_dim: The model input dim. It has to match the one used - in computing features. - - - env_info: A dict containing information about the environment. - - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 4000, - "sampling_rate": 24000, - "frame_shift_ms": 256 / 24000 * 1000, - "feat_dim": 100, - "env_info": get_env_info(), - } - ) - - return params - - -def resume_checkpoint( - params: AttributeDict, model: nn.Module, model_avg: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - logging.info(f"Resuming from file {filename}") - assert filename.is_file(), f"{filename} does not exist!" 
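    # For example, running with --start-epoch 3 resumes from exp-dir/epoch-2.pt;
    # the dict returned by load_checkpoint below also carries the optimizer,
    # scheduler and grad-scaler states, which run() restores after building them.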
- - saved_params = load_checkpoint( - filename, model=model, model_avg=model_avg, strict=True - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def compute_fbank_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - features: Tensor, - features_lens: Tensor, - tokens: List[List[int]], - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. - features: - The target acoustic feature. - features_lens: - The number of frames of each utterance. - tokens: - Input tokens that representing the transcripts. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - batch_size, num_frames, _ = features.shape - - features = torch.nn.functional.pad( - features, (0, 0, 0, num_frames - features.size(1)) - ) # (B, T, F) - noise = torch.randn_like(features) # (B, T, F) - - # Sampling t from uniform distribution - if is_training: - t = torch.rand(batch_size, 1, 1, device=device) - else: - t = ( - (torch.arange(batch_size, device=device) / batch_size) - .unsqueeze(1) - .unsqueeze(2) - ) - with torch.set_grad_enabled(is_training): - - loss = model( - tokens=tokens, - features=features, - features_lens=features_lens, - noise=noise, - t=t, - condition_drop_ratio=params.condition_drop_ratio, - ) - - assert loss.requires_grad == is_training - info = MetricsTracker() - num_frames = features_lens.sum().item() - info["frames"] = num_frames - info["loss"] = loss.detach().cpu().item() * num_frames - - return loss, info - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: TokenizerEmilia, - optimizer: Optimizer, - scheduler: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - tokenizer: - Used to convert text to tokens. - optimizer: - The optimizer. - scheduler: - The learning rate scheduler, we call step() every epoch. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. 
- """ - model.train() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - - if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - - if ( - params.valid_interval is None - and batch_idx == 0 - and not params.print_diagnostics - ) or ( - params.valid_interval is not None - and params.batch_idx_train % params.valid_interval == 0 - and not params.print_diagnostics - ): - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - params.batch_idx_train += 1 - - batch_size = len(batch["text"]) - - tokens, features, features_lens = prepare_input( - params=params, - batch=batch, - device=device, - tokenizer=tokenizer, - return_tokens=True, - return_feature=True, - ) - - try: - with autocast("cuda", enabled=params.use_fp16): - loss, loss_info = compute_fbank_loss( - params=params, - model=model, - features=features, - features_lens=features_lens, - tokens=tokens, - is_training=True, - ) - - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - scaler.scale(loss).backward() - - scheduler.step_batch(params.batch_idx_train) - # Use the number of hours of speech to adjust the learning rate - if params.lr_hours > 0: - scheduler.step_epoch( - params.batch_idx_train - * params.max_duration - * params.world_size - / 3600 - ) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except RuntimeError as e: - if "out of memory" in str(e): - logging.info(f"out of memory error at rank {rank}") - # optimizer.zero_grad() - # duration_optimizer.zero_grad() - torch.cuda.empty_cache() - raise - continue - else: - logging.info(f"Caught exception : {e}.") - save_bad_model() - raise - except Exception as e: - logging.info(f"Caught exception : {e}.") - save_bad_model() - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. 
The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 1024.0 or ( - cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - f"cur_lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - loss_value = tot_loss["loss"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: TokenizerEmilia, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to summary the stats over iterations - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - tokens, features, features_lens = prepare_input( - params=params, - batch=batch, - device=device, - tokenizer=tokenizer, - return_tokens=True, - return_feature=True, - ) - - loss, loss_info = compute_fbank_loss( - params=params, - model=model, - features=features, - features_lens=features_lens, - tokens=tokens, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - os.makedirs(f"{params.exp_dir}/fbank", exist_ok=True) - - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - if params.dataset == "emilia": - tokenizer = TokenizerEmilia( - token_file=params.token_file, token_type=params.token_type - ) - elif params.dataset == "libritts": - tokenizer = TokenizerLibriTTS( - token_file=params.token_file, token_type=params.token_type - ) - params.vocab_size = tokenizer.vocab_size - params.pad_id = tokenizer.pad_id - - params.device = device - - logging.info(params) - - logging.info("About to create model") - - model = get_model(params) - if params.checkpoint is not None: - logging.info(f"Loading pre-trained model from {params.checkpoint}") - _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True) - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of parameters : {num_param}") - - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - assert params.start_epoch > 0, params.start_epoch - if params.start_epoch > 1: - checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg) - - model = model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer = ScaledAdam( - get_parameter_groups_with_lrs( - model, - lr=params.base_lr, - include_names=True, - ), - lr=params.base_lr, # should have no effect - clipping_scale=2.0, - ) - - assert params.lr_hours >= 0 - if params.lr_hours > 0: - scheduler = Eden(optimizer, params.lr_batches, params.lr_hours) - else: - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - scaler = GradScaler("cuda", enabled=params.use_fp16) - - if params.start_epoch > 1 and checkpoints is not None: - # load state_dict for optimizers - if "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - # load state_dict for schedulers - if "scheduler" in checkpoints: - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt_emilia(c: Cut): - if c.duration < 1.0 or c.duration > 30.0: - return False - return True - - def remove_short_and_long_utt_libritts(c: Cut): - if c.duration < 1.0 or c.duration > 20.0: - return False - return True - - datamodule = TtsDataModule(args) - if params.dataset == "emilia": - train_cuts = CutSet.mux( - datamodule.train_emilia_EN_cuts(), - datamodule.train_emilia_ZH_cuts(), - weights=[46000, 49000], - ) - train_cuts = 
train_cuts.filter(remove_short_and_long_utt_emilia) - dev_cuts = CutSet.mux( - datamodule.dev_emilia_EN_cuts(), - datamodule.dev_emilia_ZH_cuts(), - weights=[0.5, 0.5], - ) - elif params.dataset == "libritts": - train_cuts = datamodule.train_libritts_cuts() - train_cuts = train_cuts.filter(remove_short_and_long_utt_libritts) - dev_cuts = datamodule.dev_libritts_cuts() - - train_dl = datamodule.train_dataloaders(train_cuts) - - valid_dl = datamodule.dev_dataloaders(dev_cuts) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - - if params.lr_hours == 0: - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - tokenizer=tokenizer, - optimizer=optimizer, - scheduler=scheduler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - TtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/zipvoice/zipvoice/tts_datamodule.py b/egs/zipvoice/zipvoice/tts_datamodule.py deleted file mode 100644 index 972c700f7..000000000 --- a/egs/zipvoice/zipvoice/tts_datamodule.py +++ /dev/null @@ -1,456 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, -# Zengwei Yao, -# Zengrui Jin, -# Han Zhu, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
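# Usage sketch: a minimal, hypothetical example of how this module is driven,
# mirroring the calls made in train_flow.py above (the argument values are
# illustrative, not prescribed defaults):
#
#     parser = argparse.ArgumentParser()
#     TtsDataModule.add_arguments(parser)
#     args = parser.parse_args(["--manifest-dir", "data/fbank", "--max-duration", "500"])
#
#     datamodule = TtsDataModule(args)
#     train_cuts = datamodule.train_libritts_cuts().filter(
#         lambda c: 1.0 <= c.duration <= 20.0
#     )
#     train_dl = datamodule.train_dataloaders(train_cuts)
#     valid_dl = datamodule.dev_dataloaders(datamodule.dev_libritts_cuts())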
- - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Union - -import torch -from feature import TorchAudioFbank, TorchAudioFbankConfig -from lhotse import CutSet, load_manifest_lazy, validate -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, -) -from lhotse.dataset.collation import collate_audio -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - BatchIO, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed, ifnone -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -SAMPLING_RATE = 24000 - - -class TtsDataModule: - """ - DataModule for tts experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank_emilia"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=100, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. 
Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=8, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - sampling_rate = SAMPLING_RATE - config = TorchAudioFbankConfig( - sampling_rate=sampling_rate, - n_mels=100, - n_fft=1024, - hop_length=256, - ) - train = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
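        # Note: the base seed sampled below is handed to _SeedWorkers, so each
        # dataloader worker is re-seeded with (seed + worker_id): workers draw
        # independent random streams while the run as a whole stays reproducible.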
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - sampling_rate = SAMPLING_RATE - config = TorchAudioFbankConfig( - sampling_rate=sampling_rate, - n_mels=100, - n_fft=1024, - hop_length=256, - ) - validate = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)), - return_cuts=self.args.return_cuts, - ) - else: - validate = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - dev_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create valid dataloader") - dev_dl = DataLoader( - validate, - sampler=dev_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return dev_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - sampling_rate = SAMPLING_RATE - config = TorchAudioFbankConfig( - sampling_rate=sampling_rate, - n_mels=100, - n_fft=1024, - hop_length=256, - ) - test = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)), - return_cuts=self.args.return_cuts, - return_audio=True, - ) - else: - test = SpeechSynthesisDataset( - return_text=True, - return_tokens=True, - return_spk_ids=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - return_audio=True, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_emilia_EN_cuts(self) -> CutSet: - logging.info("About to get train the EN subset") - return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz") - - @lru_cache() - def train_emilia_ZH_cuts(self) -> CutSet: - logging.info("About to get train the ZH subset") - return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz") - - @lru_cache() - def dev_emilia_EN_cuts(self) -> CutSet: - logging.info("About to get dev the EN subset") - return load_manifest_lazy( - self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz" - ) - - @lru_cache() - def dev_emilia_ZH_cuts(self) -> CutSet: - logging.info("About to get dev the ZH subset") - return load_manifest_lazy( - self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz" - ) - - @lru_cache() - def train_libritts_cuts(self) -> CutSet: - logging.info( - "About to get the shuffled train-clean-100, \ - train-clean-360 and train-other-500 cuts" - ) - return load_manifest_lazy( - self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz" - ) - - @lru_cache() - def dev_libritts_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest_lazy( - self.args.manifest_dir / 
"libritts_cuts_dev-clean.jsonl.gz" - ) - - -class SpeechSynthesisDataset(torch.utils.data.Dataset): - """ - The PyTorch Dataset for the speech synthesis task. - Each item in this dataset is a dict of: - - .. code-block:: - - { - 'audio': (B x NumSamples) float tensor - 'features': (B x NumFrames x NumFeatures) float tensor - 'audio_lens': (B, ) int tensor - 'features_lens': (B, ) int tensor - 'text': List[str] of len B # when return_text=True - 'tokens': List[List[str]] # when return_tokens=True - 'speakers': List[str] of len B # when return_spk_ids=True - 'cut': List of Cuts # when return_cuts=True - } - """ - - def __init__( - self, - cut_transforms: List[Callable[[CutSet], CutSet]] = None, - feature_input_strategy: BatchIO = PrecomputedFeatures(), - feature_transforms: Union[Sequence[Callable], Callable] = None, - return_text: bool = True, - return_tokens: bool = False, - return_spk_ids: bool = False, - return_cuts: bool = False, - return_audio: bool = False, - ) -> None: - super().__init__() - - self.cut_transforms = ifnone(cut_transforms, []) - self.feature_input_strategy = feature_input_strategy - - self.return_text = return_text - self.return_tokens = return_tokens - self.return_spk_ids = return_spk_ids - self.return_cuts = return_cuts - self.return_audio = return_audio - - if feature_transforms is None: - feature_transforms = [] - elif not isinstance(feature_transforms, Sequence): - feature_transforms = [feature_transforms] - - assert all( - isinstance(transform, Callable) for transform in feature_transforms - ), "Feature transforms must be Callable" - self.feature_transforms = feature_transforms - - def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: - validate_for_tts(cuts) - - for transform in self.cut_transforms: - cuts = transform(cuts) - - features, features_lens = self.feature_input_strategy(cuts) - - for transform in self.feature_transforms: - features = transform(features) - - batch = { - "features": features, - "features_lens": features_lens, - } - - if self.return_audio: - audio, audio_lens = collate_audio(cuts) - batch["audio"] = audio - batch["audio_lens"] = audio_lens - - if self.return_text: - # use normalized text - text = [cut.supervisions[0].normalized_text for cut in cuts] - batch["text"] = text - - if self.return_tokens: - tokens = [cut.tokens for cut in cuts] - batch["tokens"] = tokens - - if self.return_spk_ids: - batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts] - - if self.return_cuts: - batch["cut"] = [cut for cut in cuts] - - return batch - - -def validate_for_tts(cuts: CutSet) -> None: - validate(cuts) - for cut in cuts: - assert ( - len(cut.supervisions) == 1 - ), "Only the Cuts with single supervision are supported." 
diff --git a/egs/zipvoice/zipvoice/utils.py b/egs/zipvoice/zipvoice/utils.py deleted file mode 100644 index 4092d0ae4..000000000 --- a/egs/zipvoice/zipvoice/utils.py +++ /dev/null @@ -1,219 +0,0 @@ -from typing import Any, List, Optional, Tuple, Union - -import torch -from torch import nn -from torch.nn.parallel import DistributedDataParallel as DDP - - -class AttributeDict(dict): - def __getattr__(self, key): - if key in self: - return self[key] - raise AttributeError(f"No such attribute '{key}'") - - def __setattr__(self, key, value): - self[key] = value - - def __delattr__(self, key): - if key in self: - del self[key] - return - raise AttributeError(f"No such attribute '{key}'") - - -def prepare_input( - params: AttributeDict, - batch: dict, - device: torch.device, - tokenizer: Optional[Any] = None, - return_tokens: bool = False, - return_feature: bool = False, - return_audio: bool = False, - return_prompt: bool = False, -): - """ - Parse the features and targets of the current batch. - Args: - params: - It is returned by :func:`get_params`. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - sp: - Used to convert text to bpe tokens. - device: - The device of Tensor. - """ - return_list = [] - - if return_tokens: - assert tokenizer is not None - - if params.token_type == "phone": - tokens = tokenizer.tokens_to_token_ids(batch["tokens"]) - else: - tokens = tokenizer.texts_to_token_ids(batch["text"]) - return_list += [tokens] - - if return_feature: - features = batch["features"].to(device) - features_lens = batch["features_lens"].to(device) - return_list += [features * params.feat_scale, features_lens] - - if return_audio: - return_list += [batch["audio"], batch["audio_lens"]] - - if return_prompt: - if return_tokens: - if params.token_type == "phone": - prompt_tokens = tokenizer.tokens_to_token_ids(batch["prompt"]["tokens"]) - else: - prompt_tokens = tokenizer.texts_to_token_ids(batch["prompt"]["text"]) - return_list += [prompt_tokens] - if return_feature: - prompt_features = batch["prompt"]["features"].to(device) - prompt_features_lens = batch["prompt"]["features_lens"].to(device) - return_list += [prompt_features * params.feat_scale, prompt_features_lens] - if return_audio: - return_list += [batch["prompt"]["audio"], batch["prompt"]["audio_lens"]] - - return return_list - - -def prepare_avg_tokens_durations(features_lens, tokens_lens): - tokens_durations = [] - for i in range(len(features_lens)): - utt_duration = features_lens[i] - avg_token_duration = utt_duration // tokens_lens[i] - tokens_durations.append([avg_token_duration] * tokens_lens[i]) - return tokens_durations - - -def pad_labels(y: List[List[int]], pad_id: int, device: torch.device): - """ - Pad the transcripts to the same length with zeros. - - Args: - y: the transcripts, which is a list of a list - - Returns: - Return a Tensor of padded transcripts. - """ - y = [l + [pad_id] for l in y] - length = max([len(l) for l in y]) - y = [l + [pad_id] * (length - len(l)) for l in y] - return torch.tensor(y, dtype=torch.int64, device=device) - - -def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor: - """ - Gets position in the transcript for each frame, i.e. the position - in the symbol-sequence to look up. - - Args: - durations: - Duration of each token in transcripts. - num_frames: - The maximum frame length of the current batch. 
- - Returns: - Return a Tensor of shape (batch_size, num_frames) - """ - durations = [x + [num_frames - sum(x)] for x in durations] - batch_size = len(durations) - ans = torch.zeros(batch_size, num_frames, dtype=torch.int64) - for b in range(batch_size): - this_dur = durations[b] - cur_frame = 0 - for i, d in enumerate(this_dur): - ans[b, cur_frame : cur_frame + d] = i - cur_frame += d - assert cur_frame == num_frames, (cur_frame, num_frames) - return ans - - -def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name - - -def condition_time_mask( - features_lens: torch.Tensor, mask_percent: Tuple[float, float], max_len: int = 0 -) -> torch.Tensor: - """ - Apply Time masking. - Args: - features_lens: - input tensor of shape ``(B)`` - mask_size: - the width size for masking. - max_len: - the maximum length of the mask. - Returns: - Return a 2-D bool tensor (B, T), where masked positions - are filled with `True` and non-masked positions are - filled with `False`. - """ - mask_size = ( - torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent) - * features_lens - ).to(torch.int64) - mask_starts = ( - torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size) - ).to(torch.int64) - mask_ends = mask_starts + mask_size - max_len = max(max_len, features_lens.max()) - seq_range = torch.arange(0, max_len, device=features_lens.device) - mask = (seq_range[None, :] >= mask_starts[:, None]) & ( - seq_range[None, :] < mask_ends[:, None] - ) - return mask - - -def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: - """ - Args: - lengths: - A 1-D tensor containing sentence lengths. - max_len: - The length of masks. - Returns: - Return a 2-D bool tensor, where masked positions - are filled with `True` and non-masked positions are - filled with `False`. - - >>> lengths = torch.tensor([1, 3, 2, 5]) - >>> make_pad_mask(lengths) - tensor([[False, True, True, True, True], - [False, False, False, True, True], - [False, False, True, True, True], - [False, False, False, False, False]]) - """ - assert lengths.ndim == 1, lengths.ndim - max_len = max(max_len, lengths.max()) - n = lengths.size(0) - seq_range = torch.arange(0, max_len, device=lengths.device) - expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) - - return expaned_lengths >= lengths.unsqueeze(-1) diff --git a/egs/zipvoice/zipvoice/zipformer.py b/egs/zipvoice/zipvoice/zipformer.py deleted file mode 100644 index 190191cbb..000000000 --- a/egs/zipvoice/zipvoice/zipformer.py +++ /dev/null @@ -1,1648 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao, -# Wei Kang -# Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import math -import random -from typing import Optional, Tuple, Union - -import torch -from scaling import ( - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. -) -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) -from scaling import ( - ActivationDropoutAndLinear, - Balancer, - BiasNorm, - Dropout2, - FloatLike, - ScheduledFloat, - SwooshR, - Whiten, - limit_param_value, - penalize_abs_values_gt, - softmax, -) -from torch import Tensor, nn - - -def timestep_embedding(timesteps, dim, max_period=10000): - """Create sinusoidal timestep embeddings. - - :param timesteps: shape of (N) or (N, T) - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim) - """ - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) - / half - ) - - if timesteps.dim() == 2: - timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N) - - args = timesteps[..., None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1) - return embedding - - -class TTSZipformer(nn.Module): - """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length - as downsampling_factor if they are single ints or one-element tuples. The length of - downsampling_factor defines the number of stacks. - - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per - encoder stack. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per - attention head - value_head_dim (int or Tuple[int]): dimension of value in each attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. - feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - pos_dim (int): the dimension of each positional-encoding vector prior to projection, - e.g. 128. - - dropout (float): dropout rate - warmup_batches (float): number of batches to warm up over; this controls - dropout of encoder layers. - use_time_embed: (bool): if True, do not take time embedding as additional input. - time_embed_dim: (int): the dimension of the time embedding. 
- """ - - def __init__( - self, - in_dim: int, - out_dim: int, - downsampling_factor: Tuple[int] = (2, 4), - num_encoder_layers: Union[int, Tuple[int]] = 4, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - encoder_dim: int = 384, - query_head_dim: int = 24, - pos_head_dim: int = 4, - value_head_dim: int = 12, - num_heads: int = 8, - feedforward_dim: int = 1536, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - use_time_embed: bool = True, - time_embed_dim: int = 192, - use_guidance_scale_embed: bool = False, - guidance_scale_embed_dim: int = 192, - use_conv: bool = True, - ) -> None: - super(TTSZipformer, self).__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) - - def _to_tuple(x): - """Converts a single int or a 1-tuple of an int to a tuple with the same length - as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x - - def _assert_downsampling_factor(factors): - """assert downsampling_factor follows u-net style""" - assert factors[0] == 1 and factors[-1] == 1 - - for i in range(1, len(factors) // 2 + 1): - assert factors[i] == factors[i - 1] * 2 - - for i in range(len(factors) // 2 + 1, len(factors)): - assert factors[i] * 2 == factors[i - 1] - - _assert_downsampling_factor(downsampling_factor) - self.downsampling_factor = downsampling_factor # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - self.encoder_dim = encoder_dim - self.num_encoder_layers = num_encoder_layers - self.query_head_dim = query_head_dim - self.value_head_dim = value_head_dim - self.num_heads = num_heads - - self.use_time_embed = use_time_embed - self.use_guidance_scale_embed = use_guidance_scale_embed - - self.time_embed_dim = time_embed_dim - if self.use_time_embed: - assert time_embed_dim != -1 - else: - time_embed_dim = -1 - self.guidance_scale_embed_dim = guidance_scale_embed_dim - - self.in_proj = nn.Linear(in_dim, encoder_dim) - self.out_proj = nn.Linear(encoder_dim, out_dim) - - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder - encoders = [] - - num_encoders = len(downsampling_factor) - for i in range(num_encoders): - encoder_layer = Zipformer2EncoderLayer( - embed_dim=encoder_dim, - pos_dim=pos_dim, - num_heads=num_heads, - query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, - value_head_dim=value_head_dim, - feedforward_dim=feedforward_dim, - use_conv=use_conv, - cnn_module_kernel=cnn_module_kernel[i], - dropout=dropout, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. 
- encoder = Zipformer2Encoder( - encoder_layer, - num_encoder_layers[i], - embed_dim=encoder_dim, - time_embed_dim=time_embed_dim, - pos_dim=pos_dim, - warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), - final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), - ) - - if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( - encoder, - dim=encoder_dim, - downsample=downsampling_factor[i], - ) - - encoders.append(encoder) - - self.encoders = nn.ModuleList(encoders) - if self.use_time_embed: - self.time_embed = nn.Sequential( - nn.Linear(time_embed_dim, time_embed_dim * 2), - SwooshR(), - nn.Linear(time_embed_dim * 2, time_embed_dim), - ) - else: - self.time_embed = None - - if self.use_guidance_scale_embed: - self.guidance_scale_embed = ScaledLinear( - guidance_scale_embed_dim, time_embed_dim, bias=False, initial_scale=0.1 - ) - else: - self.guidance_scale_embed = None - - def forward( - self, - x: Tensor, - t: Optional[Tensor] = None, - padding_mask: Optional[Tensor] = None, - guidance_scale: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - t: - A t tensor of shape (batch_size,) or (batch_size, seq_len) - padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return the output embeddings. its shape is (batch_size, output_seq_len, encoder_dim) - """ - x = x.permute(1, 0, 2) - x = self.in_proj(x) - - if t is not None: - assert t.dim() == 1 or t.dim() == 2, t.shape - time_emb = timestep_embedding(t, self.time_embed_dim) - if guidance_scale is not None: - assert ( - guidance_scale.dim() == 1 or guidance_scale.dim() == 2 - ), guidance_scale.shape - guidance_scale_emb = self.guidance_scale_embed( - timestep_embedding(guidance_scale, self.guidance_scale_embed_dim) - ) - time_emb = time_emb + guidance_scale_emb - time_emb = self.time_embed(time_emb) - else: - time_emb = None - - attn_mask = None - - for i, module in enumerate(self.encoders): - x = module( - x, - time_emb=time_emb, - src_key_padding_mask=padding_mask, - attn_mask=attn_mask, - ) - x = self.out_proj(x) - x = x.permute(1, 0, 2) - return x - - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) - - -class Zipformer2EncoderLayer(nn.Module): - """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (required). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module (default=31). 
- - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - use_conv: bool = True, - attention_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - conv_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - const_attention_rate: FloatLike = ScheduledFloat( - (0.0, 0.25), (4000.0, 0.025), default=0 - ), - ff2_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - ff3_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - bypass_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.5), (4000.0, 0.02), default=0 - ), - ) -> None: - super(Zipformer2EncoderLayer, self).__init__() - self.embed_dim = embed_dim - - # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule( - embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 - ) - # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) - - # skip probability for dynamic modules (meaning: anything but feedforward). - self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) - - # ff2_skip_rate is to prevent the ff2 module from having output that's too big - # compared to its residual. - self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) - self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) - - self.const_attention_rate = copy.deepcopy(const_attention_rate) - - self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, - pos_dim=pos_dim, - num_heads=num_heads, - query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, - dropout=0.0, - ) - - self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) - - self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) - - self.feed_forward1 = FeedforwardModule( - embed_dim, (feedforward_dim * 3) // 4, dropout - ) - - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule( - embed_dim, (feedforward_dim * 5) // 4, dropout - ) - - self.nonlin_attention = NonlinAttention( - embed_dim, hidden_channels=3 * embed_dim // 4 - ) - - self.use_conv = use_conv - - if self.use_conv: - self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel) - - self.norm = BiasNorm(embed_dim) - - self.balancer1 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.2, - max_abs=4.0, - ) - - # balancer for output of NonlinAttentionModule - self.balancer_na = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage - ) - - # balancer for output of feedforward2, prevent it from staying too - # small. 
give this a very small probability, even at the start of - # training, it's to fix a rare problem and it's OK to fix it slowly. - self.balancer_ff2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), - max_abs=2.0, - prob=0.05, - ) - - self.balancer_ff3 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), - max_abs=4.0, - prob=0.05, - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.balancer2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.1, - max_abs=4.0, - ) - - def get_sequence_dropout_mask( - self, x: Tensor, dropout_rate: float - ) -> Optional[Tensor]: - if ( - dropout_rate == 0.0 - or not self.training - or torch.jit.is_scripting() - or torch.jit.is_tracing() - ): - return None - batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) - return mask - - def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: - """ - Apply sequence-level dropout to x. - x shape: (seq_len, batch_size, embed_dim) - """ - dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) - if dropout_mask is None: - return x - else: - return x * dropout_mask - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - time_emb: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim) - or (seq_len, batch_size, embedding_dim) . - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: - A tensor which has the same shape as src - """ - src_orig = src - - # dropout rate for non-feedforward submodules - if torch.jit.is_scripting() or torch.jit.is_tracing(): - attention_skip_rate = 0.0 - else: - attention_skip_rate = ( - float(self.attention_skip_rate) if self.training else 0.0 - ) - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - if time_emb is not None: - - src = src + time_emb - - src = src + self.feed_forward1(src) - - self_attn_dropout_mask = self.get_sequence_dropout_mask( - src, attention_skip_rate - ) - - selected_attn_weights = attn_weights[0:1] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < float(self.const_attention_rate): - # Make attention weights constant. The intention is to - # encourage these modules to do something similar to an - # averaging-over-time operation. 
- # only need the mask, can just use the 1st one and expand later - selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to( - selected_attn_weights.dtype - ) - selected_attn_weights = selected_attn_weights * ( - 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) - ) - - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) - - src = src + ( - na if self_attn_dropout_mask is None else na * self_attn_dropout_mask - ) - - self_attn = self.self_attn1(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) - - if self.use_conv: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - - if time_emb is not None: - src = src + time_emb - - src = src + self.sequence_dropout( - self.conv_module1( - src, - src_key_padding_mask=src_key_padding_mask, - ), - conv_skip_rate, - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff2_skip_rate = 0.0 - else: - ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate - ) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn = self.self_attn2(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) - - if self.use_conv: - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - - if time_emb is not None: - src = src + time_emb - - src = src + self.sequence_dropout( - self.conv_module2( - src, - src_key_padding_mask=src_key_padding_mask, - ), - conv_skip_rate, - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff3_skip_rate = 0.0 - else: - ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate - ) - - src = self.balancer1(src) - src = self.norm(src) - - src = self.bypass(src_orig, src) - - src = self.balancer2(src) - src = self.whiten(src) - - return src - - -class Zipformer2Encoder(nn.Module): - r"""Zipformer2Encoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). 
- pos_dim: the dimension for the relative positional encoding - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - embed_dim: int, - time_embed_dim: int, - pos_dim: int, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, - ) -> None: - super().__init__() - self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, dropout_rate=0.15, length_factor=1.0 - ) - if time_embed_dim != -1: - self.time_emb = nn.Sequential( - SwooshR(), - nn.Linear(time_embed_dim, embed_dim), - ) - else: - self.time_emb = None - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert 0 <= warmup_begin <= warmup_end - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin # interpreted as a training batch index - for i in range(num_layers): - cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat( - (cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0, - ) - cur_begin = cur_end - - def forward( - self, - src: Tensor, - time_emb: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - the embedding representing the current timestep: shape (batch_size, embedding_dim) - or (seq_len, batch_size, embedding_dim) . - time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim) - or (seq_len, batch_size, embedding_dim) . - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - pos_emb = self.encoder_pos(src) - if self.time_emb is not None: - assert time_emb is not None - time_emb = self.time_emb(time_emb) - else: - assert time_emb is None - - output = src - - for i, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - time_emb=time_emb, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - return output - - -class BypassModule(nn.Module): - """ - An nn.Module that implements a learnable bypass scale, and also randomized per-sequence - layer-skipping. The bypass is limited during early stages of training to be close to - "straight-through", i.e. to not do the bypass operation much initially, in order to - force all the modules to learn something. 
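-    Concretely, forward(src_orig, src) returns
-        src_orig + bypass_scale * (src - src_orig),
-    so a scale of 0.0 bypasses the module entirely (output == src_orig) while a scale
-    of 1.0 is fully "straight-through" (output == src); skip_rate randomly zeroes the
-    scale for whole sequences during training, implementing layer-skipping.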
- """ - - def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0, - ): - super().__init__() - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.skip_rate = copy.deepcopy(skip_rate) - self.straight_through_rate = copy.deepcopy(straight_through_rate) - self.scale_min = copy.deepcopy(scale_min) - self.scale_max = copy.deepcopy(scale_max) - - def _get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 corresponds to bypassing - # this module. - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.bypass_scale - else: - ans = limit_param_value( - self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) - ) - skip_rate = float(self.skip_rate) - if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for sequences - # on which we have randomly chosen to do layer-skipping. - straight_through_rate = float(self.straight_through_rate) - if straight_through_rate != 0.0: - mask = ( - torch.rand((batch_size, 1), device=ans.device) - < straight_through_rate - ) - ans = torch.maximum(ans, mask.to(ans.dtype)) - return ans - - def forward(self, src_orig: Tensor, src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig - """ - bypass_scale = self._get_bypass_scale(src.shape[1]) - return src_orig + (src - src_orig) * bypass_scale - - -class DownsampledZipformer2Encoder(nn.Module): - r""" - DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - - def __init__(self, encoder: nn.Module, dim: int, downsample: int): - super(DownsampledZipformer2Encoder, self).__init__() - self.downsample_factor = downsample - self.downsample = SimpleDownsample(downsample) - self.num_layers = encoder.num_layers - self.encoder = encoder - self.upsample = SimpleUpsample(downsample) - self.out_combiner = BypassModule(dim, straight_through_rate=0) - - def forward( - self, - src: Tensor, - time_emb: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim) - or (seq_len, batch_size, embedding_dim) . - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. 
- - Returns: a Tensor with the same shape as src. - """ - src_orig = src - src = self.downsample(src) - ds = self.downsample_factor - if time_emb is not None and time_emb.dim() == 3: - time_emb = time_emb[::ds] - if attn_mask is not None: - attn_mask = attn_mask[::ds, ::ds] - if src_key_padding_mask is not None: - src_key_padding_mask = src_key_padding_mask[..., ::ds] - - src = self.encoder( - src, - time_emb=time_emb, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - -class SimpleDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum. - """ - - def __init__(self, downsample: int): - super(SimpleDownsample, self).__init__() - - self.bias = nn.Parameter(torch.zeros(downsample)) - - self.name = None # will be set from training code - - self.downsample = downsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - - weights = self.bias.softmax(dim=0) - # weights: (downsample, 1, 1) - weights = weights.unsqueeze(-1).unsqueeze(-1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that just repeats the input. - """ - - def __init__(self, upsample: int): - super(SimpleUpsample, self).__init__() - self.upsample = upsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class CompactRelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. Such differences were potentially important - when encoding absolute position, but not important when encoding relative position because there - is now no need to compare two large offsets with each other. - - Our embedding works by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the Fourier transform of that fixed interval. 
The - atan() function would compress the "long tails" too small, - making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic - function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) - - - Args: - embed_dim: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. - """ - - def __init__( - self, - embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, - ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() - self.embed_dim = embed_dim - assert embed_dim % 2 == 0, embed_dim - self.dropout = Dropout2(dropout_rate) - self.pe = None - assert length_factor >= 1.0, length_factor - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - T = x.size(0) + left_context_len - - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution - # for small time offsets but less resolution for large time offsets. - compression_length = self.embed_dim**0.5 - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; - # but it does so more slowly than T for large absolute values of T. - # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which - # is important. - x_compressed = ( - compression_length - * x.sign() - * ((x.abs() + compression_length).log() - math.log(compression_length)) - ) - - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) - - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. - - self.pe = pe.to(dtype=x.dtype) - - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. - - Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. 
- - Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = self.pe[ - self.pe.size(0) // 2 - - x_size_left - + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), - :, - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - self.pos_head_dim = pos_head_dim - self.dropout = dropout - self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) - self.name = None # will be overwritten in training code; for diagnostics. - - key_head_dim = query_head_dim - in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5 that has been used in previous forms of attention, - # dividing it between the query and key. Note: this module is intended - # to be used with the ScaledAdam optimizer; with most other optimizers, - # it would be necessary to apply the scaling factor in the forward function. - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 - ) - - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=_whitening_schedule(3.0), - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - # add a balancer for the keys that runs with very small probability, and - # tries to enforce that all dimensions have mean around zero. The - # weights produced by this module are invariant to adding a constant to - # the keys, so the derivative of the bias is mathematically zero; but - # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero - # bias because the small numerical roundoff tends to have a non-random - # sign. This module is intended to prevent that. Use a very small - # probability; that should be sufficient to fix the problem. 
- self.balance_keys = Balancer( - key_head_dim * num_heads, - channel_dim=-1, - min_positive=0.4, - max_positive=0.6, - min_abs=0.0, - max_abs=100.0, - prob=0.025, - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnostics only, see --print-diagnostics option - self.copy_pos_query = Identity() - self.copy_query = Identity() - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tensor: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), - interpreted as ([batch_size,] tgt_seq_len, src_seq_len) - saying which positions are allowed to attend to which other positions. - Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - # p is the position-encoding query - p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim, ( - p.shape[-1], - num_heads, - pos_head_dim, - ) - - q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - use_pos_scores = False - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # We can't put random.random() in the same line - use_pos_scores = True - elif not self.training or random.random() >= float(self.pos_emb_skip_rate): - use_pos_scores = True - - if use_pos_scores: - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. 
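-            # Concretely: at this point the last axis of pos_scores has length
-            # 2*seq_len - 1 and indexes relative offsets; the re-indexing below turns
-            # it into a (tgt_seq_len, src_seq_len) matrix in which consecutive query
-            # rows read windows of that axis shifted by one position (e.g. for
-            # seq_len == 3, rows 0, 1 and 2 read columns [2, 3, 4], [1, 2, 3] and
-            # [0, 1, 2] of the length-5 relative axis).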
-        if torch.jit.is_tracing():
-            (num_heads, batch_size, time1, n) = pos_scores.shape
-            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-            cols = torch.arange(seq_len)
-            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
-            indexes = rows + cols
-            pos_scores = pos_scores.reshape(-1, n)
-            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
-        else:
-            pos_scores = pos_scores.as_strided(
-                (num_heads, batch_size, seq_len, seq_len),
-                (
-                    pos_scores.stride(0),
-                    pos_scores.stride(1),
-                    pos_scores.stride(2) - pos_scores.stride(3),
-                    pos_scores.stride(3),
-                ),
-                storage_offset=pos_scores.stride(3) * (seq_len - 1),
-            )
-
-        attn_scores = attn_scores + pos_scores
-
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            pass
-        elif self.training and random.random() < 0.1:
-            # This is a harder way of keeping the attention scores from becoming
-            # too large.  It incurs a penalty if any of them has an absolute value
-            # greater than the limit (25.0 here), which should be outside the normal
-            # range of the attention scores.  We use this mechanism instead of, say,
-            # adding something to the loss function involving the entropy, because
-            # once the entropy gets very small, gradients through the softmax can
-            # become very small and we'd get zero derivatives.  The choice of 1.0e-04
-            # as the scale on the penalty makes this mechanism sensitive to the
-            # absolute scale of the loss function, but we view it as a failsafe to
-            # avoid "implausible" parameter values rather than as a regularization
-            # method that should be active under normal circumstances.
-            attn_scores = penalize_abs_values_gt(
-                attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
-            )
-
-        assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
-
-        if attn_mask is not None:
-            assert attn_mask.dtype == torch.bool
-            # Use -1000 to avoid NaNs where attn_mask and key_padding_mask together
-            # make all scores zero.  It is important that this value be large enough
-            # in magnitude that exp(-1000) is exactly zero: const_attention_rate
-            # compares the final weights with zero.
-            attn_scores = attn_scores.masked_fill(attn_mask, -1000)
-
-        if key_padding_mask is not None:
-            assert key_padding_mask.shape == (
-                batch_size,
-                seq_len,
-            ), key_padding_mask.shape
-            attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
-                -1000,
-            )
-
-        # We use our own version of softmax, defined in scaling.py, which saves a
-        # little of the memory used in backprop: when in automatic mixed precision
-        # mode (amp / autocast) it only stores the half-precision output for
-        # backprop purposes.
- attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif random.random() < 0.001 and not self.training: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .mean(dim=(1, 2)) - ) - logging.debug( - f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" - ) - - -class SelfAttention(nn.Module): - """ - The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) - - self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - Returns: - a tensor with the same shape as x. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. 
- x = self.out_proj(x) - x = self.whiten(x) - - return x - - -class FeedforwardModule(nn.Module): - """Feedforward module in TTSZipformer model.""" - - def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) - - self.hidden_balancer = Balancer( - feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear( - feedforward_dim, - embed_dim, - activation="SwooshL", - dropout_p=dropout, - dropout_shared_dim=0, - bias=True, - initial_scale=0.1, - ) - - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.hidden_balancer(x) - # out_proj contains SwooshL activation, then dropout, then linear. - x = self.out_proj(x) - x = self.out_whiten(x) - return x - - -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, - # because we noticed that well-trained instances of this module have abs-value before the sigmoid - # starting from about 3, and poorly-trained instances of the module have smaller abs values - # before the sigmoid. - self.balancer = Balancer( - hidden_channels, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), - max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), - min_abs=0.5, - max_abs=5.0, - ) - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear( - hidden_channels, channels, bias=True, initial_scale=0.05 - ) - - self.whiten1 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.whiten2 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - - s = self.balancer(s) - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. 
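-
-        # The attention weights computed by RelPositionMultiheadAttentionWeights are
-        # reused here: the whitened-and-gated features are split into heads and mixed
-        # over time with those weights, taking the place of the depthwise convolution
-        # in ConvolutionModule.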
-        (seq_len, batch_size, embed_dim) = x.shape
-        num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
-
-        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-        # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = torch.matmul(attn_weights, x)
-        # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
-
-        y = self.identity2(y)
-        x = x * y
-        x = self.identity3(x)
-
-        x = self.out_proj(x)
-        x = self.whiten2(x)
-        return x
-
-
-class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Zipformer2 model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernel size of conv layers.
-
-    """
-
-    def __init__(
-        self,
-        channels: int,
-        kernel_size: int,
-    ) -> None:
-        """Construct a ConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
-        # kernel_size should be an odd number for 'SAME' padding.
-        assert (kernel_size - 1) % 2 == 0
-
-        bottleneck_dim = channels
-
-        self.in_proj = nn.Linear(
-            channels,
-            2 * bottleneck_dim,
-        )
-        # The gradients on in_proj are a little noisy, likely something to do with
-        # the sigmoid in the GLU below.
-
-        # After in_proj we put x through a gated linear unit (nn.functional.glu).
-        # For most layers the normal rms value of channels of x seems to be in the
-        # range 1 to 4, but sometimes, for some reason, for layer 0 the rms ends up
-        # being very large, between 50 and 100 for different channels.  This causes
-        # very peaky and sparse derivatives for the sigmoid gating function, which
-        # tends to keep the loss function from learning effectively.  (For most
-        # layers the average absolute values are in the range 0.5..9.0, and the
-        # average p(x>0), i.e. positive proportion, at the output of in_proj is
-        # around 0.35 to 0.45 for different layers, which likely breaks down as 0.5
-        # for the "linear" half and 0.2 to 0.3 for the part that goes into the
-        # sigmoid.)  The idea is that if we constrain the rms values to a reasonable
-        # range via a constraint of max_abs=10.0, the module will be in a better
-        # position to start learning something, i.e. to latch onto the correct range.
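-        # balancer1 below enforces that constraint on the pre-sigmoid gate values "s":
-        # its max_abs schedule (rising to 10.0) caps their magnitude, and min_abs=1.5
-        # keeps the gate from collapsing towards a constant 0.5.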
- self.balancer1 = Balancer( - bottleneck_dim, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), - max_positive=1.0, - min_abs=1.5, - max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), - ) - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics - - assert kernel_size % 2 == 1 - - self.depthwise_conv = nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2, - ) - - self.balancer2 = Balancer( - bottleneck_dim, - channel_dim=1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, - channels, - activation="SwooshR", - dropout_p=0.0, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.balancer1(s) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - x = self.depthwise_conv(x) - - x = self.balancer2(x) - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.whiten(x) # (time, batch, channels) - x = self.out_proj(x) # (time, batch, channels) - - return x diff --git a/egs/zipvoice/zipvoice/zipvoice_infer.py b/egs/zipvoice/zipvoice/zipvoice_infer.py deleted file mode 100644 index 16cf8e039..000000000 --- a/egs/zipvoice/zipvoice/zipvoice_infer.py +++ /dev/null @@ -1,645 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Xiaomi Corp. (authors: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script generates speech with our pre-trained ZipVoice or - ZipVoice-Distill models. Required models will be automatically - downloaded from HuggingFace. 
- -Usage: - -Note: If you having trouble connecting to HuggingFace, - try switching endpoint to mirror site: - -export HF_ENDPOINT=https://hf-mirror.com - -(1) Inference of a single sentence: - -python3 zipvoice/zipvoice_infer.py \ - --model-name "zipvoice_distill" \ - --prompt-wav prompt.wav \ - --prompt-text "I am a prompt." \ - --text "I am a sentence." \ - --res-wav-path result.wav - -(2) Inference of a list of sentences: -python3 zipvoice/zipvoice_infer.py \ - --model-name "zipvoice-distill" \ - --test-list test.tsv \ - --res-dir results - -`--model-name` can be `zipvoice` or `zipvoice_distill`, - which are the models before and after distillation, respectively. - -Each line of `test.tsv` is in the format of - `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`. -""" - -import argparse -import datetime as dt -import os - -import numpy as np -import safetensors.torch -import torch -import torch.nn as nn -import torchaudio -from feature import TorchAudioFbank, TorchAudioFbankConfig -from huggingface_hub import hf_hub_download -from lhotse.utils import fix_random_seed -from model import get_distill_model, get_model -from tokenizer import TokenizerEmilia -from utils import AttributeDict -from vocos import Vocos - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--model-name", - type=str, - default="zipvoice_distill", - choices=["zipvoice", "zipvoice_distill"], - help="The model used for inference", - ) - - parser.add_argument( - "--test-list", - type=str, - default=None, - help="The list of prompt speech, prompt_transcription, " - "and text to synthesizein the format of " - "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.", - ) - - parser.add_argument( - "--prompt-wav", - type=str, - default=None, - help="The prompt wav to mimic", - ) - - parser.add_argument( - "--prompt-text", - type=str, - default=None, - help="The transcription of the prompt wav", - ) - - parser.add_argument( - "--text", - type=str, - default=None, - help="The text to synthesize", - ) - - parser.add_argument( - "--res-dir", - type=str, - default="results", - help=""" - Path name of the generated wavs dir, - used when test-list is not None - """, - ) - - parser.add_argument( - "--res-wav-path", - type=str, - default="result.wav", - help=""" - Path name of the generated wav path, - used when test-list is None - """, - ) - - parser.add_argument( - "--guidance-scale", - type=float, - default=None, - help="The scale of classifier-free guidance during inference.", - ) - - parser.add_argument( - "--num-step", - type=int, - default=None, - help="The number of sampling steps.", - ) - - parser.add_argument( - "--feat-scale", - type=float, - default=0.1, - help="The scale factor of fbank feature", - ) - - parser.add_argument( - "--speed", - type=float, - default=1.0, - help="Control speech speed, 1.0 means normal, >1.0 means speed up", - ) - - parser.add_argument( - "--t-shift", - type=float, - default=0.5, - help="Shift t to smaller ones if t_shift < 1.0", - ) - - parser.add_argument( - "--target-rms", - type=float, - default=0.1, - help="Target speech normalization rms value", - ) - - parser.add_argument( - "--seed", - type=int, - default=666, - help="Random seed", - ) - - add_model_arguments(parser) - - return parser - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--fm-decoder-downsampling-factor", - type=str, - default="1,2,4,2,1", - help="Downsampling factor for each 
stack of encoder layers.", - ) - - parser.add_argument( - "--fm-decoder-num-layers", - type=str, - default="2,2,4,4,4", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--fm-decoder-cnn-module-kernel", - type=str, - default="31,15,7,15,31", - help="Sizes of convolutional kernels in convolution modules " - "in each encoder stack: a single int or comma-separated list.", - ) - - parser.add_argument( - "--fm-decoder-feedforward-dim", - type=int, - default=1536, - help="Feedforward dimension of the zipformer encoder layers, " - "per stack, comma separated.", - ) - - parser.add_argument( - "--fm-decoder-num-heads", - type=int, - default=4, - help="Number of attention heads in the zipformer encoder layers: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--fm-decoder-dim", - type=int, - default=512, - help="Embedding dimension in encoder stacks: a single int " - "or comma-separated list.", - ) - - parser.add_argument( - "--text-encoder-downsampling-factor", - type=str, - default="1", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--text-encoder-num-layers", - type=str, - default="4", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--text-encoder-feedforward-dim", - type=int, - default=512, - help="Feedforward dimension of the zipformer encoder layers, " - "per stack, comma separated.", - ) - - parser.add_argument( - "--text-encoder-cnn-module-kernel", - type=str, - default="9", - help="Sizes of convolutional kernels in convolution modules in " - "each encoder stack: a single int or comma-separated list.", - ) - - parser.add_argument( - "--text-encoder-num-heads", - type=int, - default=4, - help="Number of attention heads in the zipformer encoder layers: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--text-encoder-dim", - type=int, - default=192, - help="Embedding dimension in encoder stacks: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=int, - default=32, - help="Query/key dimension per head in encoder stacks: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=int, - default=12, - help="Value dimension per head in encoder stacks: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=int, - default=4, - help="Positional-encoding dimension per head in encoder stacks: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default=48, - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--time-embed-dim", - type=int, - default=192, - help="Embedding dimension of timestamps embedding.", - ) - - parser.add_argument( - "--text-embed-dim", - type=int, - default=192, - help="Embedding dimension of text embedding.", - ) - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sampling_rate": 24000, - "frame_shift_ms": 256 / 24000 * 1000, - "feat_dim": 100, - } - ) - - return params - - -def get_vocoder(): - vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz") - return vocoder - - -def generate_sentence( - save_path: str, - prompt_text: str, - prompt_wav: str, - text: str, - model: nn.Module, - vocoder: nn.Module, - tokenizer: TokenizerEmilia, - feature_extractor: TorchAudioFbank, - device: torch.device, - num_step: int = 16, - 
guidance_scale: float = 1.0, - speed: float = 1.0, - t_shift: float = 0.5, - target_rms: float = 0.1, - feat_scale: float = 0.1, - sampling_rate: int = 24000, -): - """ - Generate waveform of a text based on a given prompt - waveform and its transcription. - - Args: - save_path (str): Path to save the generated wav. - prompt_text (str): Transcription of the prompt wav. - prompt_wav (str): Path to the prompt wav file. - text (str): Text to be synthesized into a waveform. - model (nn.Module): The model used for generation. - vocoder (nn.Module): The vocoder used to convert features to waveforms. - tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens. - feature_extractor (TorchAudioFbank): The feature extractor used to - extract acoustic features. - device (torch.device): The device on which computations are performed. - num_step (int, optional): Number of steps for decoding. Defaults to 16. - guidance_scale (float, optional): Scale for classifier-free guidance. - Defaults to 1.0. - speed (float, optional): Speed control. Defaults to 1.0. - t_shift (float, optional): Time shift. Defaults to 0.5. - target_rms (float, optional): Target RMS for waveform normalization. - Defaults to 0.1. - feat_scale (float, optional): Scale for features. - Defaults to 0.1. - sampling_rate (int, optional): Sampling rate for the waveform. - Defaults to 24000. - Returns: - metrics (dict): Dictionary containing time and real-time - factor metrics for processing. - """ - # Convert text to tokens - tokens = tokenizer.texts_to_token_ids([text]) - prompt_tokens = tokenizer.texts_to_token_ids([prompt_text]) - - # Load and preprocess prompt wav - prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav) - prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav))) - if prompt_rms < target_rms: - prompt_wav = prompt_wav * target_rms / prompt_rms - - if prompt_sampling_rate != sampling_rate: - resampler = torchaudio.transforms.Resample( - orig_freq=prompt_sampling_rate, new_freq=sampling_rate - ) - prompt_wav = resampler(prompt_wav) - - # Extract features from prompt wav - prompt_features = feature_extractor.extract( - prompt_wav, sampling_rate=sampling_rate - ).to(device) - prompt_features = prompt_features.unsqueeze(0) * feat_scale - prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device) - - # Start timing - start_t = dt.datetime.now() - - # Generate features - ( - pred_features, - pred_features_lens, - pred_prompt_features, - pred_prompt_features_lens, - ) = model.sample( - tokens=tokens, - prompt_tokens=prompt_tokens, - prompt_features=prompt_features, - prompt_features_lens=prompt_features_lens, - speed=speed, - t_shift=t_shift, - duration="predict", - num_step=num_step, - guidance_scale=guidance_scale, - ) - - # Postprocess predicted features - pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T) - - # Start vocoder processing - start_vocoder_t = dt.datetime.now() - wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1) - - # Calculate processing times and real-time factors - t = (dt.datetime.now() - start_t).total_seconds() - t_no_vocoder = (start_vocoder_t - start_t).total_seconds() - t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds() - wav_seconds = wav.shape[-1] / sampling_rate - rtf = t / wav_seconds - rtf_no_vocoder = t_no_vocoder / wav_seconds - rtf_vocoder = t_vocoder / wav_seconds - metrics = { - "t": t, - "t_no_vocoder": t_no_vocoder, - "t_vocoder": t_vocoder, - "wav_seconds": wav_seconds, - "rtf": rtf, - 
"rtf_no_vocoder": rtf_no_vocoder, - "rtf_vocoder": rtf_vocoder, - } - - # Adjust wav volume if necessary - if prompt_rms < target_rms: - wav = wav * prompt_rms / target_rms - torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate) - - return metrics - - -def generate( - res_dir: str, - test_list: str, - model: nn.Module, - vocoder: nn.Module, - tokenizer: TokenizerEmilia, - feature_extractor: TorchAudioFbank, - device: torch.device, - num_step: int = 16, - guidance_scale: float = 1.0, - speed: float = 1.0, - t_shift: float = 0.5, - target_rms: float = 0.1, - feat_scale: float = 0.1, - sampling_rate: int = 24000, -): - total_t = [] - total_t_no_vocoder = [] - total_t_vocoder = [] - total_wav_seconds = [] - - with open(test_list, "r") as fr: - lines = fr.readlines() - - for i, line in enumerate(lines): - wav_name, prompt_text, prompt_wav, text = line.strip().split("\t") - save_path = f"{res_dir}/{wav_name}.wav" - metrics = generate_sentence( - save_path=save_path, - prompt_text=prompt_text, - prompt_wav=prompt_wav, - text=text, - model=model, - vocoder=vocoder, - tokenizer=tokenizer, - feature_extractor=feature_extractor, - device=device, - num_step=num_step, - guidance_scale=guidance_scale, - speed=speed, - t_shift=t_shift, - target_rms=target_rms, - feat_scale=feat_scale, - sampling_rate=sampling_rate, - ) - print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}") - total_t.append(metrics["t"]) - total_t_no_vocoder.append(metrics["t_no_vocoder"]) - total_t_vocoder.append(metrics["t_vocoder"]) - total_wav_seconds.append(metrics["wav_seconds"]) - - print(f"Average RTF: {np.sum(total_t)/np.sum(total_wav_seconds):.4f}") - print( - f"Average RTF w/o vocoder: " - f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}" - ) - print( - f"Average RTF vocoder: " - f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}" - ) - - -@torch.inference_mode() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - model_defaults = { - "zipvoice": { - "num_step": 16, - "guidance_scale": 1.0, - }, - "zipvoice_distill": { - "num_step": 8, - "guidance_scale": 3.0, - }, - } - - model_specific_defaults = model_defaults.get(params.model_name, {}) - - for param, value in model_specific_defaults.items(): - if getattr(params, param) == parser.get_default(param): - setattr(params, param, value) - print(f"Setting {param} to default value: {value}") - - assert (params.test_list is not None) ^ ( - (params.prompt_wav and params.prompt_text and params.text) is not None - ), ( - "For inference, please provide prompts and text with either '--test-list'" - " or '--prompt-wav, --prompt-text and --text'." 
- ) - - if torch.cuda.is_available(): - params.device = torch.device("cuda", 0) - else: - params.device = torch.device("cpu") - - token_file = hf_hub_download("zhu-han/ZipVoice", filename="tokens_emilia.txt") - - tokenizer = TokenizerEmilia(token_file) - - params.vocab_size = tokenizer.vocab_size - params.pad_id = tokenizer.pad_id - fix_random_seed(params.seed) - - if params.model_name == "zipvoice_distill": - model = get_distill_model(params) - model_ckpt = hf_hub_download( - "zhu-han/ZipVoice", filename="exp_zipvoice_distill/model.safetensors" - ) - else: - model = get_model(params) - model_ckpt = hf_hub_download( - "zhu-han/ZipVoice", filename="exp_zipvoice/model.safetensors" - ) - - safetensors.torch.load_model(model, model_ckpt) - - model = model.to(params.device) - model.eval() - - vocoder = get_vocoder() - vocoder = vocoder.to(params.device) - vocoder.eval() - - config = TorchAudioFbankConfig( - sampling_rate=params.sampling_rate, - n_mels=100, - n_fft=1024, - hop_length=256, - ) - feature_extractor = TorchAudioFbank(config) - - if params.test_list: - os.makedirs(params.res_dir, exist_ok=True) - generate( - res_dir=params.res_dir, - test_list=params.test_list, - model=model, - vocoder=vocoder, - tokenizer=tokenizer, - feature_extractor=feature_extractor, - device=params.device, - num_step=params.num_step, - guidance_scale=params.guidance_scale, - speed=params.speed, - t_shift=params.t_shift, - target_rms=params.target_rms, - feat_scale=params.feat_scale, - sampling_rate=params.sampling_rate, - ) - else: - generate_sentence( - save_path=params.res_wav_path, - prompt_text=params.prompt_text, - prompt_wav=params.prompt_wav, - text=params.text, - model=model, - vocoder=vocoder, - tokenizer=tokenizer, - feature_extractor=feature_extractor, - device=params.device, - num_step=params.num_step, - guidance_scale=params.guidance_scale, - speed=params.speed, - t_shift=params.t_shift, - target_rms=params.target_rms, - feat_scale=params.feat_scale, - sampling_rate=params.sampling_rate, - ) - print("Done") - - -if __name__ == "__main__": - main()