diff --git a/egs/zipvoice/README.md b/egs/zipvoice/README.md
deleted file mode 100644
index 0c97d7ed8..000000000
--- a/egs/zipvoice/README.md
+++ /dev/null
@@ -1,412 +0,0 @@
-## ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching
-
-
-[Paper](http://arxiv.org/abs/2506.13053)
-[Demo page](https://zipvoice.github.io/)
-
-
-## Overview
-ZipVoice is a high-quality zero-shot TTS model with a small model size and fast inference speed.
-#### Key features:
-
-- Small and fast: only 123M parameters.
-
-- High-quality: state-of-the-art voice cloning performance in speaker similarity, intelligibility, and naturalness.
-
-- Multi-lingual: supports Chinese and English.
-
-
-## News
-**2025/06/16**: 🔥 ZipVoice is released.
-
-
-## Installation
-
-* Clone the icefall repository and change to the zipvoice directory:
-
-```bash
-git clone https://github.com/k2-fsa/icefall.git
-cd icefall/egs/zipvoice
-```
-
-* Create a Python virtual environment (optional but recommended):
-
-```bash
-python3 -m venv venv
-source venv/bin/activate
-```
-
-* Install the required packages:
-
-```bash
-pip install -r requirements.txt
-```
-
-## Usage
-
-To generate speech with our pre-trained ZipVoice or ZipVoice-Distill models, use the following commands (the required models will be downloaded from HuggingFace):
-
-### 1. Inference of a single sentence:
-```bash
-python3 zipvoice/zipvoice_infer.py \
- --model-name "zipvoice_distill" \
- --prompt-wav prompt.wav \
- --prompt-text "I am the transcription of the prompt wav." \
- --text "I am the text to be synthesized." \
- --res-wav-path result.wav
-
-# Example with a pre-defined prompt wav and text
-python3 zipvoice/zipvoice_infer.py \
- --model-name "zipvoice_distill" \
- --prompt-wav assets/prompt-en.wav \
- --prompt-text "Some call me nature, others call me mother nature. I've been here for over four point five billion years, twenty two thousand five hundred times longer than you." \
- --text "Welcome to use our tts model, have fun!" \
- --res-wav-path result.wav
-```
-
-### 2. Inference of a list of sentences:
-```bash
-python3 zipvoice/zipvoice_infer.py \
- --model-name "zipvoice_distill" \
- --test-list test.tsv \
- --res-dir results/test
-```
-
-- `--model-name` can be `zipvoice` or `zipvoice_distill`, which are models before and after distillation, respectively.
-- Each line of `test.tsv` has the format `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`, for example (an illustrative tab-separated line):
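-
-```
-utt_0001	Some call me nature, others call me mother nature.	assets/prompt-en.wav	Welcome to use our tts model, have fun!
-```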
-
-
-> **Note:** If you have trouble connecting to HuggingFace, try:
-```bash
-export HF_ENDPOINT=https://hf-mirror.com
-```
-
-## Training Your Own Model
-
-The following steps show how to train a model from scratch on the Emilia and LibriTTS datasets, respectively.
-
-### 0. Install dependencies for training
-
-```bash
-# Install pytorch and k2.
-# If you want to use different versions, please refer to https://k2-fsa.org/get-started/k2/ for details.
-# For users in China mainland, please refer to https://k2-fsa.org/zh-CN/get-started/k2/
-
-# Note: Make sure you have installed the correct version of PyTorch and k2 that matches your CUDA version.
-# For example, if you want to use PyTorch 2.5.1 with CUDA 12.1, you can install PyTorch and k2 as follows:
-
-pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
-pip install k2==1.24.4.dev20250208+cuda12.1.torch2.5.1 -f https://k2-fsa.github.io/k2/cuda.html
-
-pip install -r ../../requirements.txt
-```
-
-### 1. Data Preparation
-
-#### 1.1. Prepare the Emilia dataset
-
-```bash
-bash scripts/prepare_emilia.sh
-```
-
-See [scripts/prepare_emilia.sh](scripts/prepare_emilia.sh) for step by step instructions.
-
-#### 1.2 Prepare the LibriTTS dataset
-
-```bash
-bash scripts/prepare_libritts.sh
-```
-
-See [scripts/prepare_libritts.sh](scripts/prepare_libritts.sh) for step by step instructions.
-
-### 2. Training
-
-#### 2.1 Training on Emilia
-
-
-
-##### 2.1.1 Train the ZipVoice model
-
-- Training:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/train_flow.py \
- --world-size 8 \
- --use-fp16 1 \
- --dataset emilia \
- --max-duration 500 \
- --lr-hours 30000 \
- --lr-batches 7500 \
- --token-file "data/tokens_emilia.txt" \
- --manifest-dir "data/fbank" \
- --num-epochs 11 \
- --exp-dir zipvoice/exp_zipvoice
-```
-
-- Average the checkpoints to produce the final model:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/generate_averaged_model.py \
- --epoch 11 \
- --avg 4 \
- --distill 0 \
- --token-file data/tokens_emilia.txt \
- --dataset "emilia" \
- --exp-dir ./zipvoice/exp_zipvoice
-# The generated model is zipvoice/exp_zipvoice/epoch-11-avg-4.pt
-```
-
-##### 2.1.2 Train the ZipVoice-Distill model (Optional)
-
-- The first-stage distillation:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/train_distill.py \
- --world-size 8 \
- --use-fp16 1 \
- --tensorboard 1 \
- --dataset "emilia" \
- --base-lr 0.0005 \
- --max-duration 500 \
- --token-file "data/tokens_emilia.txt" \
- --manifest-dir "data/fbank" \
- --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
- --num-updates 60000 \
- --distill-stage "first" \
- --exp-dir zipvoice/exp_zipvoice_distill_1stage
-```
-
-- Average checkpoints for the second-stage initialization:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/generate_averaged_model.py \
- --iter 60000 \
- --avg 7 \
- --distill 1 \
- --token-file data/tokens_emilia.txt \
- --dataset "emilia" \
- --exp-dir ./zipvoice/exp_zipvoice_distill_1stage
-# The generated model is zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt
-```
-
-- The second-stage distillation:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/train_distill.py \
- --world-size 8 \
- --use-fp16 1 \
- --tensorboard 1 \
- --dataset "emilia" \
- --base-lr 0.0001 \
- --max-duration 200 \
- --token-file "data/tokens_emilia.txt" \
- --manifest-dir "data/fbank" \
- --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \
- --num-updates 2000 \
- --distill-stage "second" \
- --exp-dir zipvoice/exp_zipvoice_distill_new
-```
-
-
-
-#### 2.2 Training on LibriTTS
-
-
-
-##### 2.2.1 Train the ZipVoice model
-
-- Training:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/train_flow.py \
- --world-size 8 \
- --use-fp16 1 \
- --dataset libritts \
- --max-duration 250 \
- --lr-epochs 10 \
- --lr-batches 7500 \
- --token-file "data/tokens_libritts.txt" \
- --manifest-dir "data/fbank" \
- --num-epochs 60 \
- --exp-dir zipvoice/exp_zipvoice_libritts
-```
-
-- Average the checkpoints to produce the final model:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/generate_averaged_model.py \
- --epoch 60 \
- --avg 10 \
- --distill 0 \
- --token-file data/tokens_libritts.txt \
- --dataset "libritts" \
- --exp-dir ./zipvoice/exp_zipvoice_libritts
-# The generated model is zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt
-```
-
-##### 2.2.2 Train the ZipVoice-Distill model (Optional)
-
-- The first-stage distillation:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/train_distill.py \
- --world-size 8 \
- --use-fp16 1 \
- --tensorboard 1 \
- --dataset "libritts" \
- --base-lr 0.001 \
- --max-duration 250 \
- --token-file "data/tokens_libritts.txt" \
- --manifest-dir "data/fbank" \
- --teacher-model zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
- --num-epochs 6 \
- --distill-stage "first" \
- --exp-dir zipvoice/exp_zipvoice_distill_1stage_libritts
-```
-
-- Average checkpoints for the second-stage initialization:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 ./zipvoice/generate_averaged_model.py \
- --epoch 6 \
- --avg 3 \
- --distill 1 \
- --token-file data/tokens_libritts.txt \
- --dataset "libritts" \
- --exp-dir ./zipvoice/exp_zipvoice_distill_1stage_libritts
-# The generated model is zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt
-```
-
-- The second-stage distillation:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/train_distill.py \
- --world-size 8 \
- --use-fp16 1 \
- --tensorboard 1 \
- --dataset "libritts" \
- --base-lr 0.001 \
- --max-duration 250 \
- --token-file "data/tokens_libritts.txt" \
- --manifest-dir "data/fbank" \
- --teacher-model zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \
- --num-epochs 6 \
- --distill-stage "second" \
- --exp-dir zipvoice/exp_zipvoice_distill_libritts
-```
-
-- Average checkpoints to produce the final model:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 ./zipvoice/generate_averaged_model.py \
- --epoch 6 \
- --avg 3 \
- --distill 1 \
- --token-file data/tokens_libritts.txt \
- --dataset "libritts" \
- --exp-dir ./zipvoice/exp_zipvoice_distill_libritts
-# The generated model is ./zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt
-```
-
-
-
-### 3. Inference with the trained model
-
-#### 3.1 Inference with the model trained on Emilia
-
-
-##### 3.1.1 ZipVoice model (before distillation):
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/infer.py \
- --checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
- --distill 0 \
- --token-file "data/tokens_emilia.txt" \
- --test-list test.tsv \
- --res-dir results/test \
- --num-step 16 \
- --guidance-scale 1
-```
-
-##### 3.1.2 ZipVoice-Distill model:
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/infer.py \
- --checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
- --distill 1 \
- --token-file "data/tokens_emilia.txt" \
- --test-list test.tsv \
- --res-dir results/test_distill \
- --num-step 8 \
- --guidance-scale 3
-```
-
-
-
-#### 3.2 Inference with the model trained on LibriTTS
-
-
-
-##### 3.2.1 ZipVoice model (before distillation):
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/infer.py \
- --checkpoint zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
- --distill 0 \
- --token-file "data/tokens_libritts.txt" \
- --test-list test.tsv \
- --res-dir results/test_libritts \
- --num-step 8 \
- --guidance-scale 1 \
- --target-rms 1.0 \
- --t-shift 0.7
-```
-
-##### 3.2.2 ZipVoice-Distill model:
-
-```bash
-export PYTHONPATH=../../:$PYTHONPATH
-python3 zipvoice/infer.py \
- --checkpoint zipvoice/exp_zipvoice_distill/epoch-6-avg-3.pt \
- --distill 1 \
- --token-file "data/tokens_libritts.txt" \
- --test-list test.tsv \
- --res-dir results/test_distill_libritts \
- --num-step 4 \
- --guidance-scale 3 \
- --target-rms 1.0 \
- --t-shift 0.7
-```
-
-
-### 4. Evaluation on benchmarks
-
-See [local/evaluate.sh](local/evaluate.sh) for details of the objective-metric evaluation
-on three test sets: LibriSpeech-PC test-clean, Seed-TTS test-en, and Seed-TTS test-zh.
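-
-The metrics are computed by the scripts under `local/`. Below is a minimal sketch of the
-per-metric commands, assuming the generated wavs are in `results/test`, the corresponding
-test list is `test.tsv`, and the metric model checkpoints have been downloaded to the
-paths described in each script:
-
-```bash
-# Speaker similarity (SIM-o)
-python3 local/evaluate_sim.py --eval-path results/test --test-list test.tsv
-
-# UTMOS (predicted mean opinion score)
-python3 local/evaluate_utmos.py --wav-path results/test
-
-# WER with Whisper-large-v3 (English), following Seed-TTS
-python3 local/evaluate_wer_seedtts.py --wav-path results/test --test-list test.tsv --lang en
-```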
-
-
-## Citation
-
-```bibtex
-@article{zhu-2025-zipvoice,
- title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching},
-  author={Han Zhu and Wei Kang and Zengwei Yao and Liyong Guo and Fangjun Kuang and Zhaoqing Li and Weiji Zhuang and Long Lin and Daniel Povey},
- journal={arXiv preprint arXiv:2506.13053},
- year={2025},
-}
-```
diff --git a/egs/zipvoice/assets/prompt-en.wav b/egs/zipvoice/assets/prompt-en.wav
deleted file mode 100644
index b7047ce9b..000000000
Binary files a/egs/zipvoice/assets/prompt-en.wav and /dev/null differ
diff --git a/egs/zipvoice/local/compute_fbank.py b/egs/zipvoice/local/compute_fbank.py
deleted file mode 100644
index 0c440b995..000000000
--- a/egs/zipvoice/local/compute_fbank.py
+++ /dev/null
@@ -1,287 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2025 Xiaomi Corp. (authors: Wei Kang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
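-"""
-Compute fbank features for a dataset with lhotse and store them under --dest-dir.
-
-Example usage (a sketch; the dataset and subset names are assumptions and should
-match the manifests produced by the data preparation scripts):
-
-    python3 local/compute_fbank.py \
-        --dataset libritts \
-        --subset dev-clean \
-        --source-dir data/manifests \
-        --dest-dir data/fbank
-"""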
-
-import argparse
-import logging
-import os
-from concurrent.futures import ProcessPoolExecutor as Pool
-from pathlib import Path
-from typing import Optional
-
-import lhotse
-import torch
-from feature import TorchAudioFbank, TorchAudioFbankConfig
-from lhotse import (
- CutSet,
- LilcomChunkyWriter,
- load_manifest_lazy,
- set_audio_duration_mismatch_tolerance,
-)
-
-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slows things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-
-def str2bool(v):
- """Used in argparse.ArgumentParser.add_argument to indicate
- that a type is a bool type and user can enter
-
- - yes, true, t, y, 1, to represent True
- - no, false, f, n, 0, to represent False
-
- See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
- """
- if isinstance(v, bool):
- return v
- if v.lower() in ("yes", "true", "t", "y", "1"):
- return True
- elif v.lower() in ("no", "false", "f", "n", "0"):
- return False
- else:
- raise argparse.ArgumentTypeError("Boolean value expected.")
-
-
-def get_args():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--sampling-rate",
- type=int,
- default=24000,
- help="The target sampling rate, the audio will be resampled to this sampling_rate.",
- )
-
- parser.add_argument(
- "--frame-shift",
- type=int,
- default=256,
- help="Frame shift in samples",
- )
-
- parser.add_argument(
- "--frame-length",
- type=int,
- default=1024,
- help="Frame length in samples",
- )
-
- parser.add_argument(
- "--num-mel-bins",
- type=int,
- default=100,
- help="The num of mel filters.",
- )
-
- parser.add_argument(
- "--dataset",
- type=str,
- help="Dataset name.",
- )
-
- parser.add_argument(
- "--subset",
- type=str,
- help="The subset of the dataset.",
- )
-
- parser.add_argument(
- "--source-dir",
- type=str,
- default="data/manifests",
- help="The source directory of manifest files.",
- )
-
- parser.add_argument(
- "--dest-dir",
- type=str,
- default="data/fbank",
- help="The destination directory of manifest files.",
- )
-
- parser.add_argument(
- "--split-cuts",
- type=str2bool,
- default=False,
- help="Whether to use splited cuts.",
- )
-
- parser.add_argument(
- "--split-begin",
- type=int,
- help="Start idx of splited cuts.",
- )
-
- parser.add_argument(
- "--split-end",
- type=int,
- help="End idx of splited cuts.",
- )
-
- parser.add_argument(
- "--batch-duration",
- type=int,
- default=1000,
- help="The batch duration when computing the features.",
- )
-
- parser.add_argument(
- "--num-jobs", type=int, default=20, help="The number of extractor workers."
- )
-
- return parser.parse_args()
-
-
-def compute_fbank_split_single(params, idx):
- lhotse.set_audio_duration_mismatch_tolerance(0.1) # for emilia
- src_dir = Path(params.source_dir)
- output_dir = Path(params.dest_dir)
- num_mel_bins = params.num_mel_bins
-
- if not src_dir.exists():
- logging.error(f"{src_dir} not exists")
- return
-
- if not output_dir.exists():
- output_dir.mkdir(parents=True, exist_ok=True)
-
- num_digits = 8
-
- config = TorchAudioFbankConfig(
- sampling_rate=params.sampling_rate,
- n_mels=params.num_mel_bins,
- n_fft=params.frame_length,
- hop_length=params.frame_shift,
- )
- extractor = TorchAudioFbank(config)
-
- prefix = params.dataset
- subset = params.subset
- suffix = "jsonl.gz"
-
- idx = f"{idx}".zfill(num_digits)
- cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}"
-
- if (src_dir / cuts_filename).is_file():
- logging.info(f"Loading manifests {src_dir / cuts_filename}")
- cut_set = load_manifest_lazy(src_dir / cuts_filename)
- else:
- logging.warning(f"Raw {cuts_filename} not exists, skipping")
- return
-
- cut_set = cut_set.resample(params.sampling_rate)
-
- if (output_dir / cuts_filename).is_file():
- logging.info(f"{cuts_filename} already exists - skipping.")
- return
-
- logging.info(f"Processing {subset}.{idx} of {prefix}")
-
- cut_set = cut_set.compute_and_store_features_batch(
- extractor=extractor,
- storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}",
- num_workers=4,
- batch_duration=params.batch_duration,
- storage_type=LilcomChunkyWriter,
- overwrite=True,
- )
- cut_set.to_file(output_dir / cuts_filename)
-
-
-def compute_fbank_split(params):
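-    # Process the manifest splits with indices in [split_begin, split_end),
-    # running one compute_fbank_split_single job per split in a process pool.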
- if params.split_end < params.split_begin:
- logging.warning(
- f"Split begin should be smaller than split end, given "
- f"{params.split_begin} -> {params.split_end}."
- )
-
- with Pool(max_workers=params.num_jobs) as pool:
- futures = [
- pool.submit(compute_fbank_split_single, params, i)
- for i in range(params.split_begin, params.split_end)
- ]
- for f in futures:
- f.result()
-
-
-def compute_fbank(params):
- src_dir = Path(params.source_dir)
- output_dir = Path(params.dest_dir)
- num_jobs = params.num_jobs
- num_mel_bins = params.num_mel_bins
-
- prefix = params.dataset
- subset = params.subset
- suffix = "jsonl.gz"
-
- cut_set_name = f"{prefix}_cuts_{subset}.{suffix}"
-
- if (src_dir / cut_set_name).is_file():
- logging.info(f"Loading manifests {src_dir / cut_set_name}")
- cut_set = load_manifest_lazy(src_dir / cut_set_name)
- else:
- recordings = load_manifest_lazy(
- src_dir / f"{prefix}_recordings_{subset}.{suffix}"
- )
- supervisions = load_manifest_lazy(
- src_dir / f"{prefix}_supervisions_{subset}.{suffix}"
- )
- cut_set = CutSet.from_manifests(
- recordings=recordings,
- supervisions=supervisions,
- )
-
- cut_set = cut_set.resample(params.sampling_rate)
-
- config = TorchAudioFbankConfig(
- sampling_rate=params.sampling_rate,
- n_mels=params.num_mel_bins,
- n_fft=params.frame_length,
- hop_length=params.frame_shift,
- )
- extractor = TorchAudioFbank(config)
-
- cuts_filename = f"{prefix}_cuts_{subset}.{suffix}"
- if (output_dir / cuts_filename).is_file():
- logging.info(f"{prefix} {subset} already exists - skipping.")
- return
- logging.info(f"Processing {subset} of {prefix}")
-
- cut_set = cut_set.compute_and_store_features(
- extractor=extractor,
- storage_path=f"{output_dir}/{prefix}_feats_{subset}",
- num_jobs=num_jobs,
- storage_type=LilcomChunkyWriter,
- )
- cut_set.to_file(output_dir / cuts_filename)
-
-
-if __name__ == "__main__":
- formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
- logging.basicConfig(format=formatter, level=logging.INFO)
- args = get_args()
- logging.info(vars(args))
- if args.split_cuts:
- compute_fbank_split(params=args)
- else:
- compute_fbank(params=args)
diff --git a/egs/zipvoice/local/evaluate_sim.py b/egs/zipvoice/local/evaluate_sim.py
deleted file mode 100644
index df439cf2c..000000000
--- a/egs/zipvoice/local/evaluate_sim.py
+++ /dev/null
@@ -1,508 +0,0 @@
-"""
-Calculate pairwise Speaker Similarity between two speech directories.
-SV model wavlm_large_finetune.pth is downloaded from
- https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
-SSL model wavlm_large.pt is downloaded from
- https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
-"""
-import argparse
-import logging
-import os
-
-import librosa
-import numpy as np
-import soundfile as sf
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from tqdm import tqdm
-
-logging.basicConfig(level=logging.INFO)
-
-
-def get_parser():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--eval-path", type=str, help="path of the evaluated speech directory"
- )
- parser.add_argument(
- "--test-list",
- type=str,
- help="path of the file list that contains the corresponding "
- "relationship between the prompt and evaluated speech. "
- "The first column is the wav name and the third column is the prompt speech",
- )
- parser.add_argument(
- "--sv-model-path",
- type=str,
- default="model/UniSpeech/wavlm_large_finetune.pth",
- help="path of the wavlm-based ECAPA-TDNN model",
- )
- parser.add_argument(
- "--ssl-model-path",
- type=str,
- default="model/s3prl/wavlm_large.pt",
- help="path of the wavlm SSL model",
- )
- return parser
-
-
-class SpeakerSimilarity:
- def __init__(
- self,
- sv_model_path="model/UniSpeech/wavlm_large_finetune.pth",
- ssl_model_path="model/s3prl/wavlm_large.pt",
- ):
- """
- Initialize
- """
- self.sample_rate = 16000
- self.channels = 1
- self.device = (
- torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- )
- logging.info("[Speaker Similarity] Using device: {}".format(self.device))
- self.model = ECAPA_TDNN_WAVLLM(
- feat_dim=1024,
- channels=512,
- emb_dim=256,
- sr=16000,
- ssl_model_path=ssl_model_path,
- )
- state_dict = torch.load(
- sv_model_path, map_location=lambda storage, loc: storage
- )
- self.model.load_state_dict(state_dict["model"], strict=False)
- self.model.to(self.device)
- self.model.eval()
-
- def get_embeddings(self, wav_list, dtype="float32"):
- """
- Get embeddings
- """
-
- def _load_speech_task(fname, sample_rate):
-
- wav_data, sr = sf.read(fname, dtype=dtype)
- if sr != sample_rate:
- wav_data = librosa.resample(
- wav_data, orig_sr=sr, target_sr=self.sample_rate
- )
- wav_data = torch.from_numpy(wav_data)
-
- return wav_data
-
- embd_lst = []
- for file_path in tqdm(wav_list):
- speech = _load_speech_task(file_path, self.sample_rate)
- speech = speech.to(self.device)
- with torch.no_grad():
- embd = self.model([speech])
- embd_lst.append(embd)
-
- return embd_lst
-
- def score(
- self,
- eval_path,
- test_list,
- dtype="float32",
- ):
- """
- Computes the Speaker Similarity (SIM-o) between two directories of speech files.
-
- Parameters:
- - eval_path (str): Path to the directory containing evaluation speech files.
- - test_list (str): Path to the file containing the corresponding relationship
- between prompt and evaluated speech.
- - dtype (str, optional): Data type for loading speech. Default is "float32".
-
- Returns:
- - float: The Speaker Similarity (SIM-o) score between the two directories
- of speech files.
- """
- prompt_wavs = []
- eval_wavs = []
- with open(test_list, "r") as fr:
- lines = fr.readlines()
- for line in lines:
- wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
- prompt_wavs.append(prompt_wav)
- eval_wavs.append(os.path.join(eval_path, wav_name + ".wav"))
- embds_prompt = self.get_embeddings(prompt_wavs, dtype=dtype)
-
- embds_eval = self.get_embeddings(eval_wavs, dtype=dtype)
-
- # Check if embeddings are empty
- if len(embds_prompt) == 0:
- logging.info("[Speaker Similarity] real set dir is empty, exiting...")
- return -1
- if len(embds_eval) == 0:
- logging.info("[Speaker Similarity] eval set dir is empty, exiting...")
- return -1
-
- scores = []
- for real_embd, eval_embd in zip(embds_prompt, embds_eval):
- scores.append(
- torch.nn.functional.cosine_similarity(real_embd, eval_embd, dim=-1)
- .detach()
- .cpu()
- .numpy()
- )
-
- return np.mean(scores)
-
-
-# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
-
-""" Res2Conv1d + BatchNorm1d + ReLU
-"""
-
-
-class Res2Conv1dReluBn(nn.Module):
- """
- in_channels == out_channels == channels
- """
-
- def __init__(
- self,
- channels,
- kernel_size=1,
- stride=1,
- padding=0,
- dilation=1,
- bias=True,
- scale=4,
- ):
- super().__init__()
- assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
- self.scale = scale
- self.width = channels // scale
- self.nums = scale if scale == 1 else scale - 1
-
- self.convs = []
- self.bns = []
- for i in range(self.nums):
- self.convs.append(
- nn.Conv1d(
- self.width,
- self.width,
- kernel_size,
- stride,
- padding,
- dilation,
- bias=bias,
- )
- )
- self.bns.append(nn.BatchNorm1d(self.width))
- self.convs = nn.ModuleList(self.convs)
- self.bns = nn.ModuleList(self.bns)
-
- def forward(self, x):
- out = []
- spx = torch.split(x, self.width, 1)
- for i in range(self.nums):
- if i == 0:
- sp = spx[i]
- else:
- sp = sp + spx[i]
- # Order: conv -> relu -> bn
- sp = self.convs[i](sp)
- sp = self.bns[i](F.relu(sp))
- out.append(sp)
- if self.scale != 1:
- out.append(spx[self.nums])
- out = torch.cat(out, dim=1)
-
- return out
-
-
-""" Conv1d + BatchNorm1d + ReLU
-"""
-
-
-class Conv1dReluBn(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size=1,
- stride=1,
- padding=0,
- dilation=1,
- bias=True,
- ):
- super().__init__()
- self.conv = nn.Conv1d(
- in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
- )
- self.bn = nn.BatchNorm1d(out_channels)
-
- def forward(self, x):
- return self.bn(F.relu(self.conv(x)))
-
-
-""" The SE connection of 1D case.
-"""
-
-
-class SE_Connect(nn.Module):
- def __init__(self, channels, se_bottleneck_dim=128):
- super().__init__()
- self.linear1 = nn.Linear(channels, se_bottleneck_dim)
- self.linear2 = nn.Linear(se_bottleneck_dim, channels)
-
- def forward(self, x):
- out = x.mean(dim=2)
- out = F.relu(self.linear1(out))
- out = torch.sigmoid(self.linear2(out))
- out = x * out.unsqueeze(2)
-
- return out
-
-
-""" SE-Res2Block of the ECAPA-TDNN architecture.
-"""
-
-
-# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
-# return nn.Sequential(
-# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
-# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
-# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
-# SE_Connect(channels)
-# )
-
-
-class SE_Res2Block(nn.Module):
- def __init__(
- self,
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- dilation,
- scale,
- se_bottleneck_dim,
- ):
- super().__init__()
- self.Conv1dReluBn1 = Conv1dReluBn(
- in_channels, out_channels, kernel_size=1, stride=1, padding=0
- )
- self.Res2Conv1dReluBn = Res2Conv1dReluBn(
- out_channels, kernel_size, stride, padding, dilation, scale=scale
- )
- self.Conv1dReluBn2 = Conv1dReluBn(
- out_channels, out_channels, kernel_size=1, stride=1, padding=0
- )
- self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
-
- self.shortcut = None
- if in_channels != out_channels:
- self.shortcut = nn.Conv1d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=1,
- )
-
- def forward(self, x):
- residual = x
- if self.shortcut:
- residual = self.shortcut(x)
-
- x = self.Conv1dReluBn1(x)
- x = self.Res2Conv1dReluBn(x)
- x = self.Conv1dReluBn2(x)
- x = self.SE_Connect(x)
-
- return x + residual
-
-
-""" Attentive weighted mean and standard deviation pooling.
-"""
-
-
-class AttentiveStatsPool(nn.Module):
- def __init__(self, in_dim, attention_channels=128, global_context_att=False):
- super().__init__()
- self.global_context_att = global_context_att
-
- # Use Conv1d with stride == 1 rather than Linear,
- # then we don't need to transpose inputs.
- if global_context_att:
- self.linear1 = nn.Conv1d(
- in_dim * 3, attention_channels, kernel_size=1
- ) # equals W and b in the paper
- else:
- self.linear1 = nn.Conv1d(
- in_dim, attention_channels, kernel_size=1
- ) # equals W and b in the paper
- self.linear2 = nn.Conv1d(
- attention_channels, in_dim, kernel_size=1
- ) # equals V and k in the paper
-
- def forward(self, x):
-
- if self.global_context_att:
- context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
- context_std = torch.sqrt(
- torch.var(x, dim=-1, keepdim=True) + 1e-10
- ).expand_as(x)
- x_in = torch.cat((x, context_mean, context_std), dim=1)
- else:
- x_in = x
-
- # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
- alpha = torch.tanh(self.linear1(x_in))
- # alpha = F.relu(self.linear1(x_in))
- alpha = torch.softmax(self.linear2(alpha), dim=2)
- mean = torch.sum(alpha * x, dim=2)
- residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
- std = torch.sqrt(residuals.clamp(min=1e-9))
- return torch.cat([mean, std], dim=1)
-
-
-class ECAPA_TDNN_WAVLLM(nn.Module):
- def __init__(
- self,
- feat_dim=80,
- channels=512,
- emb_dim=192,
- global_context_att=False,
- sr=16000,
- ssl_model_path=None,
- ):
- super().__init__()
- self.sr = sr
-
- if ssl_model_path is None:
- self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
- else:
- self.feature_extract = torch.hub.load(
- os.path.dirname(ssl_model_path),
- "wavlm_local",
- source="local",
- ckpt=ssl_model_path,
- )
-
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
- self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
- ):
- self.feature_extract.model.encoder.layers[
- 23
- ].self_attn.fp32_attention = False
- if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
- self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
- ):
- self.feature_extract.model.encoder.layers[
- 11
- ].self_attn.fp32_attention = False
-
- self.feat_num = self.get_feat_num()
- self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
-
- self.instance_norm = nn.InstanceNorm1d(feat_dim)
- # self.channels = [channels] * 4 + [channels * 3]
- self.channels = [channels] * 4 + [1536]
-
- self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
- self.layer2 = SE_Res2Block(
- self.channels[0],
- self.channels[1],
- kernel_size=3,
- stride=1,
- padding=2,
- dilation=2,
- scale=8,
- se_bottleneck_dim=128,
- )
- self.layer3 = SE_Res2Block(
- self.channels[1],
- self.channels[2],
- kernel_size=3,
- stride=1,
- padding=3,
- dilation=3,
- scale=8,
- se_bottleneck_dim=128,
- )
- self.layer4 = SE_Res2Block(
- self.channels[2],
- self.channels[3],
- kernel_size=3,
- stride=1,
- padding=4,
- dilation=4,
- scale=8,
- se_bottleneck_dim=128,
- )
-
- # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
- cat_channels = channels * 3
- self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
- self.pooling = AttentiveStatsPool(
- self.channels[-1],
- attention_channels=128,
- global_context_att=global_context_att,
- )
- self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
- self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
-
- def get_feat_num(self):
- self.feature_extract.eval()
- wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
- with torch.no_grad():
- features = self.feature_extract(wav)
- select_feature = features["hidden_states"]
- if isinstance(select_feature, (list, tuple)):
- return len(select_feature)
- else:
- return 1
-
- def get_feat(self, x):
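-        # Weighted sum of the WavLM hidden states across layers (softmax over
-        # self.feature_weight), followed by instance normalization.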
- with torch.no_grad():
- x = self.feature_extract([sample for sample in x])
-
- x = x["hidden_states"]
- if isinstance(x, (list, tuple)):
- x = torch.stack(x, dim=0)
- else:
- x = x.unsqueeze(0)
- norm_weights = (
- F.softmax(self.feature_weight, dim=-1)
- .unsqueeze(-1)
- .unsqueeze(-1)
- .unsqueeze(-1)
- )
- x = (norm_weights * x).sum(dim=0)
- x = torch.transpose(x, 1, 2) + 1e-6
-
- x = self.instance_norm(x)
- return x
-
- def forward(self, x):
- x = self.get_feat(x)
-
- out1 = self.layer1(x)
- out2 = self.layer2(out1)
- out3 = self.layer3(out2)
- out4 = self.layer4(out3)
-
- out = torch.cat([out2, out3, out4], dim=1)
- out = F.relu(self.conv(out))
- out = self.bn(self.pooling(out))
- out = self.linear(out)
-
- return out
-
-
-if __name__ == "__main__":
- parser = get_parser()
- args = parser.parse_args()
- SIM = SpeakerSimilarity(
- sv_model_path=args.sv_model_path, ssl_model_path=args.ssl_model_path
- )
- score = SIM.score(args.eval_path, args.test_list)
- logging.info(f"SIM-o score: {score:.3f}")
diff --git a/egs/zipvoice/local/evaluate_utmos.py b/egs/zipvoice/local/evaluate_utmos.py
deleted file mode 100644
index 369e139c1..000000000
--- a/egs/zipvoice/local/evaluate_utmos.py
+++ /dev/null
@@ -1,294 +0,0 @@
-"""
-Calculate UTMOS scores with the automatic Mean Opinion Score (MOS) prediction system
-adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
-
-# Download model checkpoints
-wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
-wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
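-
-# Example usage (a sketch; paths follow the argument defaults below):
-#   python3 local/evaluate_utmos.py \
-#       --wav-path results/test \
-#       --utmos-model-path model/huggingface/utmos/utmos.pt \
-#       --ssl-model-path model/huggingface/utmos/wav2vec_small.pt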
-"""
-
-import argparse
-import logging
-import os
-
-import fairseq
-import librosa
-import numpy as np
-import pytorch_lightning as pl
-import soundfile as sf
-import torch
-import torch.nn as nn
-from tqdm import tqdm
-
-logging.basicConfig(level=logging.INFO)
-
-
-def get_parser():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--wav-path", type=str, help="path of the evaluated speech directory"
- )
- parser.add_argument(
- "--utmos-model-path",
- type=str,
- default="model/huggingface/utmos/utmos.pt",
- help="path of the UTMOS model",
- )
- parser.add_argument(
- "--ssl-model-path",
- type=str,
- default="model/huggingface/utmos/wav2vec_small.pt",
- help="path of the wav2vec SSL model",
- )
- return parser
-
-
-class UTMOSScore:
- """Predicting score for each audio clip."""
-
- def __init__(self, utmos_model_path, ssl_model_path):
- self.sample_rate = 16000
- self.device = (
- torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- )
- self.model = (
- BaselineLightningModule.load_from_checkpoint(
- utmos_model_path, ssl_model_path=ssl_model_path
- )
- .eval()
- .to(self.device)
- )
-
- def score(self, wavs: torch.Tensor) -> torch.Tensor:
- """
- Args:
-            wavs: waveforms to be evaluated. A tensor with 1 or 2 dimensions is
-                treated as a single audio clip; a tensor with 3 dimensions is
-                processed as a batch.
- """
- if len(wavs.shape) == 1:
- out_wavs = wavs.unsqueeze(0).unsqueeze(0)
- elif len(wavs.shape) == 2:
- out_wavs = wavs.unsqueeze(0)
- elif len(wavs.shape) == 3:
- out_wavs = wavs
- else:
- raise ValueError("Dimension of input tensor needs to be <= 3.")
- bs = out_wavs.shape[0]
- batch = {
- "wav": out_wavs,
- "domains": torch.zeros(bs, dtype=torch.int).to(self.device),
- "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
- }
- with torch.no_grad():
- output = self.model(batch)
-
- return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
-
- def score_dir(self, dir, dtype="float32"):
- def _load_speech_task(fname, sample_rate):
-
- wav_data, sr = sf.read(fname, dtype=dtype)
- if sr != sample_rate:
- wav_data = librosa.resample(
- wav_data, orig_sr=sr, target_sr=self.sample_rate
- )
- wav_data = torch.from_numpy(wav_data)
-
- return wav_data
-
- score_lst = []
- for fname in tqdm(os.listdir(dir)):
- speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate)
- speech = speech.to(self.device)
- with torch.no_grad():
- score = self.score(speech)
- score_lst.append(score.item())
- return np.mean(score_lst)
-
-
-def load_ssl_model(ckpt_path="wav2vec_small.pt"):
- SSL_OUT_DIM = 768
- model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
- [ckpt_path]
- )
- ssl_model = model[0]
- ssl_model.remove_pretraining_modules()
- return SSL_model(ssl_model, SSL_OUT_DIM)
-
-
-class BaselineLightningModule(pl.LightningModule):
- def __init__(self, ssl_model_path):
- super().__init__()
- self.construct_model(ssl_model_path)
- self.save_hyperparameters()
-
- def construct_model(self, ssl_model_path):
- self.feature_extractors = nn.ModuleList(
- [
- load_ssl_model(ckpt_path=ssl_model_path),
- DomainEmbedding(3, 128),
- ]
- )
- output_dim = sum(
- [
- feature_extractor.get_output_dim()
- for feature_extractor in self.feature_extractors
- ]
- )
- output_layers = [
- LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
- ]
- output_dim = output_layers[-1].get_output_dim()
- output_layers.append(
- Projection(
- hidden_dim=2048,
- activation=torch.nn.ReLU(),
- range_clipping=False,
- input_dim=output_dim,
- )
- )
-
- self.output_layers = nn.ModuleList(output_layers)
-
- def forward(self, inputs):
- outputs = {}
- for feature_extractor in self.feature_extractors:
- outputs.update(feature_extractor(inputs))
- x = outputs
- for output_layer in self.output_layers:
- x = output_layer(x, inputs)
- return x
-
-
-class SSL_model(nn.Module):
- def __init__(self, ssl_model, ssl_out_dim) -> None:
- super(SSL_model, self).__init__()
- self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
-
- def forward(self, batch):
- wav = batch["wav"]
- wav = wav.squeeze(1) # [batches, wav_len]
- res = self.ssl_model(wav, mask=False, features_only=True)
- x = res["x"]
- return {"ssl-feature": x}
-
- def get_output_dim(self):
- return self.ssl_out_dim
-
-
-class DomainEmbedding(nn.Module):
- def __init__(self, n_domains, domain_dim) -> None:
- super().__init__()
- self.embedding = nn.Embedding(n_domains, domain_dim)
- self.output_dim = domain_dim
-
- def forward(self, batch):
- return {"domain-feature": self.embedding(batch["domains"])}
-
- def get_output_dim(self):
- return self.output_dim
-
-
-class LDConditioner(nn.Module):
- """
- Conditions ssl output by listener embedding
- """
-
- def __init__(self, input_dim, judge_dim, num_judges=None):
- super().__init__()
- self.input_dim = input_dim
- self.judge_dim = judge_dim
- self.num_judges = num_judges
-        assert num_judges is not None
- self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
- # concat [self.output_layer, phoneme features]
-
- self.decoder_rnn = nn.LSTM(
- input_size=self.input_dim + self.judge_dim,
- hidden_size=512,
- num_layers=1,
- batch_first=True,
- bidirectional=True,
- ) # linear?
- self.out_dim = self.decoder_rnn.hidden_size * 2
-
- def get_output_dim(self):
- return self.out_dim
-
- def forward(self, x, batch):
- judge_ids = batch["judge_id"]
- if "phoneme-feature" in x.keys():
- concatenated_feature = torch.cat(
- (
- x["ssl-feature"],
- x["phoneme-feature"]
- .unsqueeze(1)
- .expand(-1, x["ssl-feature"].size(1), -1),
- ),
- dim=2,
- )
- else:
- concatenated_feature = x["ssl-feature"]
- if "domain-feature" in x.keys():
- concatenated_feature = torch.cat(
- (
- concatenated_feature,
- x["domain-feature"]
- .unsqueeze(1)
- .expand(-1, concatenated_feature.size(1), -1),
- ),
- dim=2,
- )
-        if judge_ids is not None:
- concatenated_feature = torch.cat(
- (
- concatenated_feature,
- self.judge_embedding(judge_ids)
- .unsqueeze(1)
- .expand(-1, concatenated_feature.size(1), -1),
- ),
- dim=2,
- )
- decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
- return decoder_output
-
-
-class Projection(nn.Module):
- def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
- super(Projection, self).__init__()
- self.range_clipping = range_clipping
- output_dim = 1
- if range_clipping:
- self.proj = nn.Tanh()
-
- self.net = nn.Sequential(
- nn.Linear(input_dim, hidden_dim),
- activation,
- nn.Dropout(0.3),
- nn.Linear(hidden_dim, output_dim),
- )
- self.output_dim = output_dim
-
- def forward(self, x, batch):
- output = self.net(x)
-
- # range clipping
- if self.range_clipping:
- return self.proj(output) * 2.0 + 3
- else:
- return output
-
- def get_output_dim(self):
- return self.output_dim
-
-
-if __name__ == "__main__":
- parser = get_parser()
- args = parser.parse_args()
- UTMOS = UTMOSScore(
- utmos_model_path=args.utmos_model_path, ssl_model_path=args.ssl_model_path
- )
- score = UTMOS.score_dir(args.wav_path)
- logging.info(f"UTMOS score: {score:.2f}")
diff --git a/egs/zipvoice/local/evaluate_wer_hubert.py b/egs/zipvoice/local/evaluate_wer_hubert.py
deleted file mode 100644
index d30346e67..000000000
--- a/egs/zipvoice/local/evaluate_wer_hubert.py
+++ /dev/null
@@ -1,172 +0,0 @@
-"""
-Calculate WER with Hubert models.
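-
-Example usage (a sketch; the model path follows the example in the argument help
-below, and the wav directory / test list follow the README conventions):
-
-    python3 local/evaluate_wer_hubert.py \
-        --wav-path results/test \
-        --test-list test.tsv \
-        --model-path model/huggingface/hubert-large-ls960-ft \
-        --decode-path results/test/wer_hubert.txt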
-"""
-import argparse
-import os
-import re
-from pathlib import Path
-
-import librosa
-import numpy as np
-import soundfile as sf
-import torch
-from jiwer import compute_measures
-from tqdm import tqdm
-from transformers import pipeline
-
-
-def get_parser():
- parser = argparse.ArgumentParser()
-
- parser.add_argument("--wav-path", type=str, help="path of the speech directory")
- parser.add_argument(
- "--decode-path",
- type=str,
- default=None,
- help="path of the output file of WER information",
- )
- parser.add_argument(
- "--model-path",
- type=str,
- default=None,
- help="path of the local hubert model, e.g., model/huggingface/hubert-large-ls960-ft",
- )
- parser.add_argument(
- "--test-list",
- type=str,
- default="test.tsv",
- help="path of the transcript tsv file, where the first column "
- "is the wav name and the last column is the transcript",
- )
- parser.add_argument(
- "--batch-size", type=int, default=16, help="decoding batch size"
- )
- return parser
-
-
-def post_process(text: str):
- text = text.replace("‘", "'")
- text = text.replace("’", "'")
- text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
- text = re.sub(r"\s+", " ", text)
- text = text.strip()
- return text
-
-
-def process_one(hypo, truth):
- truth = post_process(truth)
- hypo = post_process(hypo)
-
- measures = compute_measures(truth, hypo)
- word_num = len(truth.split(" "))
- wer = measures["wer"]
- subs = measures["substitutions"]
- dele = measures["deletions"]
- inse = measures["insertions"]
- return (truth, hypo, wer, subs, dele, inse, word_num)
-
-
-class SpeechEvalDataset(torch.utils.data.Dataset):
- def __init__(self, wav_path: str, test_list: str):
- super().__init__()
- self.wav_name = []
- self.wav_paths = []
- self.transcripts = []
- with Path(test_list).open("r", encoding="utf8") as f:
- meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
- for item in meta:
- self.wav_name.append(item[0])
- self.wav_paths.append(Path(wav_path, item[0] + ".wav"))
- self.transcripts.append(item[-1])
-
- def __len__(self):
- return len(self.wav_paths)
-
- def __getitem__(self, index: int):
- wav, sampling_rate = sf.read(self.wav_paths[index])
- item = {
- "array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000),
- "sampling_rate": 16000,
- "reference": self.transcripts[index],
- "wav_name": self.wav_name[index],
- }
- return item
-
-
-def main(test_list, wav_path, model_path, decode_path, batch_size, device):
-
- if model_path is not None:
- pipe = pipeline(
- "automatic-speech-recognition",
- model=model_path,
- device=device,
- tokenizer=model_path,
- )
- else:
- pipe = pipeline(
- "automatic-speech-recognition",
- model="facebook/hubert-large-ls960-ft",
- device=device,
- )
-
- dataset = SpeechEvalDataset(wav_path, test_list)
-
- bar = tqdm(
- pipe(
- dataset,
- generate_kwargs={"language": "english", "task": "transcribe"},
- batch_size=batch_size,
- ),
- total=len(dataset),
- )
-
- wers = []
- inses = []
- deles = []
- subses = []
- word_nums = 0
- if decode_path:
- decode_dir = os.path.dirname(decode_path)
- if not os.path.exists(decode_dir):
- os.makedirs(decode_dir)
- fout = open(decode_path, "w")
- for out in bar:
- wav_name = out["wav_name"][0]
- transcription = post_process(out["text"].strip())
- text_ref = post_process(out["reference"][0].strip())
- truth, hypo, wer, subs, dele, inse, word_num = process_one(
- transcription, text_ref
- )
- if decode_path:
- fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
- wers.append(float(wer))
- inses.append(float(inse))
- deles.append(float(dele))
- subses.append(float(subs))
- word_nums += word_num
-
- wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
- subs = round(np.mean(subses) * 100, 3)
- dele = round(np.mean(deles) * 100, 3)
- inse = round(np.mean(inses) * 100, 3)
- print(f"WER: {wer}%\n")
- if decode_path:
- fout.write(f"WER: {wer}%\n")
- fout.flush()
-
-
-if __name__ == "__main__":
- parser = get_parser()
- args = parser.parse_args()
- if torch.cuda.is_available():
- device = torch.device("cuda", 0)
- else:
- device = torch.device("cpu")
- main(
- args.test_list,
- args.wav_path,
- args.model_path,
- args.decode_path,
- args.batch_size,
- device,
- )
diff --git a/egs/zipvoice/local/evaluate_wer_seedtts.py b/egs/zipvoice/local/evaluate_wer_seedtts.py
deleted file mode 100644
index f7e256387..000000000
--- a/egs/zipvoice/local/evaluate_wer_seedtts.py
+++ /dev/null
@@ -1,181 +0,0 @@
-"""
-Calculate WER with Whisper-large-v3 or Paraformer models,
-following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval
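-
-Example usage (a sketch; the model path follows the example in the argument help
-below, and the wav directory / test list follow the README conventions):
-
-    python3 local/evaluate_wer_seedtts.py \
-        --wav-path results/test \
-        --test-list test.tsv \
-        --model-path model/huggingface/whisper-large-v3 \
-        --decode-path results/test/wer_seedtts.txt \
-        --lang en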
-"""
-
-import argparse
-import os
-import string
-
-import numpy as np
-import scipy
-import soundfile as sf
-import torch
-import zhconv
-from funasr import AutoModel
-from jiwer import compute_measures
-from tqdm import tqdm
-from transformers import WhisperForConditionalGeneration, WhisperProcessor
-from zhon.hanzi import punctuation
-
-
-def get_parser():
- parser = argparse.ArgumentParser()
-
- parser.add_argument("--wav-path", type=str, help="path of the speech directory")
- parser.add_argument(
- "--decode-path",
- type=str,
- default=None,
- help="path of the output file of WER information",
- )
- parser.add_argument(
- "--model-path",
- type=str,
- default=None,
- help="path of the local whisper and paraformer model, "
- "e.g., whisper: model/huggingface/whisper-large-v3/, "
- "paraformer: model/huggingface/paraformer-zh/",
- )
- parser.add_argument(
- "--test-list",
- type=str,
- default="test.tsv",
- help="path of the transcript tsv file, where the first column "
- "is the wav name and the last column is the transcript",
- )
- parser.add_argument("--lang", type=str, help="decoded language, zh or en")
- return parser
-
-
-def load_en_model(model_path):
- if model_path is None:
- model_path = "openai/whisper-large-v3"
- processor = WhisperProcessor.from_pretrained(model_path)
- model = WhisperForConditionalGeneration.from_pretrained(model_path)
- return processor, model
-
-
-def load_zh_model(model_path):
- if model_path is None:
- model_path = "paraformer-zh"
- model = AutoModel(model=model_path)
- return model
-
-
-def process_one(hypo, truth, lang):
- punctuation_all = punctuation + string.punctuation
- for x in punctuation_all:
- if x == "'":
- continue
- truth = truth.replace(x, "")
- hypo = hypo.replace(x, "")
-
-    truth = truth.replace("  ", " ")
-    hypo = hypo.replace("  ", " ")
-
- if lang == "zh":
- truth = " ".join([x for x in truth])
- hypo = " ".join([x for x in hypo])
- elif lang == "en":
- truth = truth.lower()
- hypo = hypo.lower()
- else:
- raise NotImplementedError
-
- measures = compute_measures(truth, hypo)
- word_num = len(truth.split(" "))
- wer = measures["wer"]
- subs = measures["substitutions"]
- dele = measures["deletions"]
- inse = measures["insertions"]
- return (truth, hypo, wer, subs, dele, inse, word_num)
-
-
-def main(test_list, wav_path, model_path, decode_path, lang, device):
- if lang == "en":
- processor, model = load_en_model(model_path)
- model.to(device)
- elif lang == "zh":
- model = load_zh_model(model_path)
- params = []
- for line in open(test_list).readlines():
- line = line.strip()
- items = line.split("\t")
- wav_name, text_ref = items[0], items[-1]
- file_path = os.path.join(wav_path, wav_name + ".wav")
- assert os.path.exists(file_path), f"{file_path}"
-
- params.append((file_path, text_ref))
- wers = []
- inses = []
- deles = []
- subses = []
- word_nums = 0
- if decode_path:
- decode_dir = os.path.dirname(decode_path)
- if not os.path.exists(decode_dir):
- os.makedirs(decode_dir)
- fout = open(decode_path, "w")
- for wav_path, text_ref in tqdm(params):
- if lang == "en":
- wav, sr = sf.read(wav_path)
- if sr != 16000:
- wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
- input_features = processor(
- wav, sampling_rate=16000, return_tensors="pt"
- ).input_features
- input_features = input_features.to(device)
- forced_decoder_ids = processor.get_decoder_prompt_ids(
- language="english", task="transcribe"
- )
- predicted_ids = model.generate(
- input_features, forced_decoder_ids=forced_decoder_ids
- )
- transcription = processor.batch_decode(
- predicted_ids, skip_special_tokens=True
- )[0]
- elif lang == "zh":
- res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
- transcription = res[0]["text"]
- transcription = zhconv.convert(transcription, "zh-cn")
-
- truth, hypo, wer, subs, dele, inse, word_num = process_one(
- transcription, text_ref, lang
- )
- if decode_path:
- fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
- wers.append(float(wer))
- inses.append(float(inse))
- deles.append(float(dele))
- subses.append(float(subs))
- word_nums += word_num
-
- wer_avg = round(np.mean(wers) * 100, 3)
- wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
- subs = round(np.mean(subses) * 100, 3)
- dele = round(np.mean(deles) * 100, 3)
- inse = round(np.mean(inses) * 100, 3)
- print(f"Seed-TTS WER: {wer_avg}%\n")
- print(f"WER: {wer}%\n")
- if decode_path:
- fout.write(f"SeedTTS WER: {wer_avg}%\n")
- fout.write(f"WER: {wer}%\n")
- fout.flush()
-
-
-if __name__ == "__main__":
- parser = get_parser()
- args = parser.parse_args()
- if torch.cuda.is_available():
- device = torch.device("cuda", 0)
- else:
- device = torch.device("cpu")
- main(
- args.test_list,
- args.wav_path,
- args.model_path,
- args.decode_path,
- args.lang,
- device,
- )
diff --git a/egs/zipvoice/local/feature.py b/egs/zipvoice/local/feature.py
deleted file mode 120000
index 08ef7d228..000000000
--- a/egs/zipvoice/local/feature.py
+++ /dev/null
@@ -1 +0,0 @@
-../zipvoice/feature.py
\ No newline at end of file
diff --git a/egs/zipvoice/local/pinyin.txt b/egs/zipvoice/local/pinyin.txt
deleted file mode 100644
index cd8d14dc3..000000000
--- a/egs/zipvoice/local/pinyin.txt
+++ /dev/null
@@ -1,1550 +0,0 @@
-a
-a1
-a2
-a3
-a4
-ai1
-ai2
-ai3
-ai4
-an1
-an2
-an3
-an4
-ang1
-ang2
-ang3
-ang4
-ao1
-ao2
-ao3
-ao4
-ba
-ba1
-ba2
-ba3
-ba4
-bai
-bai1
-bai2
-bai3
-bai4
-ban
-ban1
-ban3
-ban4
-bang1
-bang3
-bang4
-bao1
-bao2
-bao3
-bao4
-bei
-bei1
-bei3
-bei4
-ben1
-ben3
-ben4
-beng
-beng1
-beng2
-beng3
-beng4
-bi1
-bi2
-bi3
-bi4
-bian
-bian1
-bian3
-bian4
-biang2
-biao1
-biao3
-biao4
-bie1
-bie2
-bie3
-bie4
-bin
-bin1
-bin3
-bin4
-bing1
-bing3
-bing4
-bo
-bo1
-bo2
-bo3
-bo4
-bu1
-bu2
-bu3
-bu4
-ca1
-ca3
-ca4
-cai1
-cai2
-cai3
-cai4
-can1
-can2
-can3
-can4
-cang1
-cang2
-cang3
-cang4
-cao1
-cao2
-cao3
-cao4
-ce4
-cei4
-cen1
-cen2
-ceng1
-ceng2
-ceng4
-cha1
-cha2
-cha3
-cha4
-chai1
-chai2
-chai3
-chai4
-chan1
-chan2
-chan3
-chan4
-chang
-chang1
-chang2
-chang3
-chang4
-chao1
-chao2
-chao3
-chao4
-che1
-che2
-che3
-che4
-chen
-chen1
-chen2
-chen3
-chen4
-cheng1
-cheng2
-cheng3
-cheng4
-chi
-chi1
-chi2
-chi3
-chi4
-chong1
-chong2
-chong3
-chong4
-chou1
-chou2
-chou3
-chou4
-chu
-chu1
-chu2
-chu3
-chu4
-chua1
-chua3
-chua4
-chuai1
-chuai2
-chuai3
-chuai4
-chuan1
-chuan2
-chuan3
-chuan4
-chuang1
-chuang2
-chuang3
-chuang4
-chui1
-chui2
-chui3
-chui4
-chun1
-chun2
-chun3
-chuo1
-chuo4
-ci1
-ci2
-ci3
-ci4
-cong1
-cong2
-cong3
-cong4
-cou1
-cou2
-cou3
-cou4
-cu1
-cu2
-cu3
-cu4
-cuan1
-cuan2
-cuan4
-cui
-cui1
-cui3
-cui4
-cun1
-cun2
-cun3
-cun4
-cuo1
-cuo2
-cuo3
-cuo4
-da
-da1
-da2
-da3
-da4
-dai
-dai1
-dai3
-dai4
-dan1
-dan3
-dan4
-dang
-dang1
-dang3
-dang4
-dao1
-dao2
-dao3
-dao4
-de
-de1
-de2
-dei1
-dei3
-den4
-deng1
-deng3
-deng4
-di1
-di2
-di3
-di4
-dia3
-dian1
-dian2
-dian3
-dian4
-diao1
-diao3
-diao4
-die1
-die2
-die3
-die4
-din4
-ding1
-ding3
-ding4
-diu1
-dong1
-dong3
-dong4
-dou1
-dou3
-dou4
-du1
-du2
-du3
-du4
-duan1
-duan3
-duan4
-dui1
-dui3
-dui4
-dun1
-dun3
-dun4
-duo
-duo1
-duo2
-duo3
-duo4
-e
-e1
-e2
-e3
-e4
-ei1
-ei2
-ei3
-ei4
-en1
-en3
-en4
-eng1
-er
-er2
-er3
-er4
-fa
-fa1
-fa2
-fa3
-fa4
-fan1
-fan2
-fan3
-fan4
-fang
-fang1
-fang2
-fang3
-fang4
-fei1
-fei2
-fei3
-fei4
-fen1
-fen2
-fen3
-fen4
-feng1
-feng2
-feng3
-feng4
-fiao4
-fo2
-fou1
-fou2
-fou3
-fu
-fu1
-fu2
-fu3
-fu4
-ga1
-ga2
-ga3
-ga4
-gai1
-gai3
-gai4
-gan1
-gan3
-gan4
-gang1
-gang3
-gang4
-gao1
-gao3
-gao4
-ge1
-ge2
-ge3
-ge4
-gei3
-gen1
-gen2
-gen3
-gen4
-geng1
-geng3
-geng4
-gong
-gong1
-gong3
-gong4
-gou1
-gou3
-gou4
-gu
-gu1
-gu2
-gu3
-gu4
-gua1
-gua2
-gua3
-gua4
-guai1
-guai3
-guai4
-guan1
-guan3
-guan4
-guang
-guang1
-guang3
-guang4
-gui1
-gui3
-gui4
-gun3
-gun4
-guo
-guo1
-guo2
-guo3
-guo4
-ha1
-ha2
-ha3
-ha4
-hai
-hai1
-hai2
-hai3
-hai4
-han
-han1
-han2
-han3
-han4
-hang1
-hang2
-hang3
-hang4
-hao1
-hao2
-hao3
-hao4
-he1
-he2
-he3
-he4
-hei1
-hen1
-hen2
-hen3
-hen4
-heng1
-heng2
-heng4
-hm
-hng
-hong1
-hong2
-hong3
-hong4
-hou1
-hou2
-hou3
-hou4
-hu
-hu1
-hu2
-hu3
-hu4
-hua1
-hua2
-hua4
-huai
-huai2
-huai4
-huan1
-huan2
-huan3
-huan4
-huang
-huang1
-huang2
-huang3
-huang4
-hui
-hui1
-hui2
-hui3
-hui4
-hun1
-hun2
-hun3
-hun4
-huo
-huo1
-huo2
-huo3
-huo4
-ji1
-ji2
-ji3
-ji4
-jia
-jia1
-jia2
-jia3
-jia4
-jian
-jian1
-jian3
-jian4
-jiang
-jiang1
-jiang3
-jiang4
-jiao
-jiao1
-jiao2
-jiao3
-jiao4
-jie
-jie1
-jie2
-jie3
-jie4
-jin1
-jin3
-jin4
-jing
-jing1
-jing3
-jing4
-jiong1
-jiong3
-jiong4
-jiu
-jiu1
-jiu2
-jiu3
-jiu4
-ju
-ju1
-ju2
-ju3
-ju4
-juan1
-juan3
-juan4
-jue1
-jue2
-jue3
-jue4
-jun1
-jun3
-jun4
-ka1
-ka3
-kai1
-kai3
-kai4
-kan1
-kan3
-kan4
-kang1
-kang2
-kang3
-kang4
-kao1
-kao3
-kao4
-ke
-ke1
-ke2
-ke3
-ke4
-kei1
-ken1
-ken3
-ken4
-keng1
-keng3
-kong1
-kong3
-kong4
-kou1
-kou3
-kou4
-ku1
-ku2
-ku3
-ku4
-kua1
-kua3
-kua4
-kuai3
-kuai4
-kuan1
-kuan3
-kuang1
-kuang2
-kuang3
-kuang4
-kui1
-kui2
-kui3
-kui4
-kun
-kun1
-kun3
-kun4
-kuo4
-la
-la1
-la2
-la3
-la4
-lai2
-lai3
-lai4
-lan2
-lan3
-lan4
-lang
-lang1
-lang2
-lang3
-lang4
-lao
-lao1
-lao2
-lao3
-lao4
-le
-le1
-le4
-lei
-lei1
-lei2
-lei3
-lei4
-len4
-leng1
-leng2
-leng3
-leng4
-li
-li1
-li2
-li3
-li4
-lia3
-lian2
-lian3
-lian4
-liang
-liang2
-liang3
-liang4
-liao1
-liao2
-liao3
-liao4
-lie
-lie1
-lie2
-lie3
-lie4
-lin1
-lin2
-lin3
-lin4
-ling
-ling1
-ling2
-ling3
-ling4
-liu1
-liu2
-liu3
-liu4
-lo
-long1
-long2
-long3
-long4
-lou
-lou1
-lou2
-lou3
-lou4
-lu
-lu1
-lu2
-lu3
-lu4
-luan2
-luan3
-luan4
-lun1
-lun2
-lun3
-lun4
-luo
-luo1
-luo2
-luo3
-luo4
-lv2
-lv3
-lv4
-lve3
-lve4
-m1
-m2
-m4
-ma
-ma1
-ma2
-ma3
-ma4
-mai2
-mai3
-mai4
-man1
-man2
-man3
-man4
-mang1
-mang2
-mang3
-mang4
-mao1
-mao2
-mao3
-mao4
-me
-me1
-mei2
-mei3
-mei4
-men
-men1
-men2
-men4
-meng
-meng1
-meng2
-meng3
-meng4
-mi1
-mi2
-mi3
-mi4
-mian2
-mian3
-mian4
-miao1
-miao2
-miao3
-miao4
-mie
-mie1
-mie2
-mie4
-min
-min2
-min3
-ming
-ming2
-ming3
-ming4
-miu3
-miu4
-mo
-mo1
-mo2
-mo3
-mo4
-mou1
-mou2
-mou3
-mou4
-mu2
-mu3
-mu4
-n
-n2
-n3
-n4
-na
-na1
-na2
-na3
-na4
-nai2
-nai3
-nai4
-nan1
-nan2
-nan3
-nan4
-nang
-nang1
-nang2
-nang3
-nang4
-nao1
-nao2
-nao3
-nao4
-ne
-ne2
-ne4
-nei2
-nei3
-nei4
-nen4
-neng2
-neng3
-neng4
-ng
-ng2
-ng3
-ng4
-ni1
-ni2
-ni3
-ni4
-nia1
-nian1
-nian2
-nian3
-nian4
-niang2
-niang3
-niang4
-niao3
-niao4
-nie1
-nie2
-nie3
-nie4
-nin
-nin2
-nin3
-ning2
-ning3
-ning4
-niu1
-niu2
-niu3
-niu4
-nong2
-nong3
-nong4
-nou2
-nou3
-nou4
-nu2
-nu3
-nu4
-nuan2
-nuan3
-nuan4
-nun2
-nun4
-nuo2
-nuo3
-nuo4
-nv2
-nv3
-nv4
-nve4
-o
-o1
-o2
-o3
-o4
-ou
-ou1
-ou2
-ou3
-ou4
-pa1
-pa2
-pa3
-pa4
-pai1
-pai2
-pai3
-pai4
-pan1
-pan2
-pan3
-pan4
-pang1
-pang2
-pang3
-pang4
-pao1
-pao2
-pao3
-pao4
-pei1
-pei2
-pei3
-pei4
-pen1
-pen2
-pen3
-pen4
-peng1
-peng2
-peng3
-peng4
-pi1
-pi2
-pi3
-pi4
-pian1
-pian2
-pian3
-pian4
-piao1
-piao2
-piao3
-piao4
-pie1
-pie3
-pie4
-pin1
-pin2
-pin3
-pin4
-ping1
-ping2
-ping3
-ping4
-po
-po1
-po2
-po3
-po4
-pou1
-pou2
-pou3
-pou4
-pu
-pu1
-pu2
-pu3
-pu4
-qi
-qi1
-qi2
-qi3
-qi4
-qia1
-qia2
-qia3
-qia4
-qian
-qian1
-qian2
-qian3
-qian4
-qiang1
-qiang2
-qiang3
-qiang4
-qiao1
-qiao2
-qiao3
-qiao4
-qie1
-qie2
-qie3
-qie4
-qin1
-qin2
-qin3
-qin4
-qing
-qing1
-qing2
-qing3
-qing4
-qiong1
-qiong2
-qiong4
-qiu1
-qiu2
-qiu3
-qiu4
-qu
-qu1
-qu2
-qu3
-qu4
-quan
-quan1
-quan2
-quan3
-quan4
-que1
-que2
-que4
-qun1
-qun2
-qun3
-ran2
-ran3
-ran4
-rang1
-rang2
-rang3
-rang4
-rao2
-rao3
-rao4
-re2
-re3
-re4
-ren2
-ren3
-ren4
-reng1
-reng2
-reng4
-ri4
-rong
-rong1
-rong2
-rong3
-rong4
-rou2
-rou3
-rou4
-ru
-ru2
-ru3
-ru4
-rua2
-ruan2
-ruan3
-ruan4
-rui2
-rui3
-rui4
-run2
-run3
-run4
-ruo2
-ruo4
-sa
-sa1
-sa3
-sa4
-sai1
-sai3
-sai4
-san
-san1
-san3
-san4
-sang1
-sang3
-sang4
-sao1
-sao3
-sao4
-se1
-se4
-sen1
-sen3
-seng1
-seng4
-sha
-sha1
-sha2
-sha3
-sha4
-shai1
-shai3
-shai4
-shan1
-shan2
-shan3
-shan4
-shang
-shang1
-shang3
-shang4
-shao1
-shao2
-shao3
-shao4
-she1
-she2
-she3
-she4
-shei2
-shen1
-shen2
-shen3
-shen4
-sheng1
-sheng2
-sheng3
-sheng4
-shi
-shi1
-shi2
-shi3
-shi4
-shou
-shou1
-shou2
-shou3
-shou4
-shu1
-shu2
-shu3
-shu4
-shua1
-shua3
-shua4
-shuai1
-shuai3
-shuai4
-shuan1
-shuan4
-shuang1
-shuang3
-shuang4
-shui
-shui2
-shui3
-shui4
-shun3
-shun4
-shuo1
-shuo2
-shuo4
-si
-si1
-si2
-si3
-si4
-song1
-song2
-song3
-song4
-sou1
-sou3
-sou4
-su1
-su2
-su3
-su4
-suan1
-suan3
-suan4
-sui1
-sui2
-sui3
-sui4
-sun1
-sun3
-sun4
-suo
-suo1
-suo2
-suo3
-suo4
-ta
-ta1
-ta2
-ta3
-ta4
-tai
-tai1
-tai2
-tai3
-tai4
-tan1
-tan2
-tan3
-tan4
-tang1
-tang2
-tang3
-tang4
-tao1
-tao2
-tao3
-tao4
-te
-te4
-tei1
-teng1
-teng2
-teng4
-ti
-ti1
-ti2
-ti3
-ti4
-tian1
-tian2
-tian3
-tian4
-tiao
-tiao1
-tiao2
-tiao3
-tiao4
-tie1
-tie2
-tie3
-tie4
-ting1
-ting2
-ting3
-ting4
-tong1
-tong2
-tong3
-tong4
-tou
-tou1
-tou2
-tou3
-tou4
-tu
-tu1
-tu2
-tu3
-tu4
-tuan1
-tuan2
-tuan3
-tuan4
-tui1
-tui2
-tui3
-tui4
-tun1
-tun2
-tun3
-tun4
-tuo1
-tuo2
-tuo3
-tuo4
-wa
-wa1
-wa2
-wa3
-wa4
-wai
-wai1
-wai3
-wai4
-wan1
-wan2
-wan3
-wan4
-wang1
-wang2
-wang3
-wang4
-wei
-wei1
-wei2
-wei3
-wei4
-wen
-wen1
-wen2
-wen3
-wen4
-weng1
-weng3
-weng4
-wo1
-wo3
-wo4
-wong4
-wu
-wu1
-wu2
-wu3
-wu4
-xi1
-xi2
-xi3
-xi4
-xia1
-xia2
-xia3
-xia4
-xian
-xian1
-xian2
-xian3
-xian4
-xiang1
-xiang2
-xiang3
-xiang4
-xiao
-xiao1
-xiao2
-xiao3
-xiao4
-xie1
-xie2
-xie3
-xie4
-xin
-xin1
-xin2
-xin3
-xin4
-xing
-xing1
-xing2
-xing3
-xing4
-xiong1
-xiong2
-xiong3
-xiong4
-xiu1
-xiu2
-xiu3
-xiu4
-xu
-xu1
-xu2
-xu3
-xu4
-xuan1
-xuan2
-xuan3
-xuan4
-xue1
-xue2
-xue3
-xue4
-xun1
-xun2
-xun4
-ya
-ya1
-ya2
-ya3
-ya4
-yan1
-yan2
-yan3
-yan4
-yang
-yang1
-yang2
-yang3
-yang4
-yao1
-yao2
-yao3
-yao4
-ye
-ye1
-ye2
-ye3
-ye4
-yi
-yi1
-yi2
-yi3
-yi4
-yin
-yin1
-yin2
-yin3
-yin4
-ying1
-ying2
-ying3
-ying4
-yo
-yo1
-yong1
-yong2
-yong3
-yong4
-you
-you1
-you2
-you3
-you4
-yu
-yu1
-yu2
-yu3
-yu4
-yuan1
-yuan2
-yuan3
-yuan4
-yue1
-yue2
-yue3
-yue4
-yun
-yun1
-yun2
-yun3
-yun4
-za1
-za2
-za3
-za4
-zai1
-zai3
-zai4
-zan
-zan1
-zan2
-zan3
-zan4
-zang1
-zang3
-zang4
-zao1
-zao2
-zao3
-zao4
-ze
-ze2
-ze4
-zei2
-zen
-zen1
-zen3
-zen4
-zeng1
-zeng3
-zeng4
-zha
-zha1
-zha2
-zha3
-zha4
-zhai1
-zhai2
-zhai3
-zhai4
-zhan1
-zhan2
-zhan3
-zhan4
-zhang
-zhang1
-zhang3
-zhang4
-zhao
-zhao1
-zhao2
-zhao3
-zhao4
-zhe
-zhe1
-zhe2
-zhe3
-zhe4
-zhei4
-zhen1
-zhen2
-zhen3
-zhen4
-zheng1
-zheng3
-zheng4
-zhi
-zhi1
-zhi2
-zhi3
-zhi4
-zhong1
-zhong3
-zhong4
-zhou1
-zhou2
-zhou3
-zhou4
-zhu1
-zhu2
-zhu3
-zhu4
-zhua1
-zhua3
-zhuai1
-zhuai3
-zhuai4
-zhuan1
-zhuan2
-zhuan3
-zhuan4
-zhuang1
-zhuang3
-zhuang4
-zhui1
-zhui3
-zhui4
-zhun1
-zhun3
-zhun4
-zhuo
-zhuo1
-zhuo2
-zhuo4
-zi
-zi1
-zi2
-zi3
-zi4
-zong
-zong1
-zong3
-zong4
-zou1
-zou3
-zou4
-zu1
-zu2
-zu3
-zu4
-zuan1
-zuan3
-zuan4
-zui
-zui1
-zui2
-zui3
-zui4
-zun1
-zun2
-zun3
-zun4
-zuo
-zuo1
-zuo2
-zuo3
-zuo4
-ê1
-ê2
-ê3
-ê4
diff --git a/egs/zipvoice/local/prepare_token_file_emilia.py b/egs/zipvoice/local/prepare_token_file_emilia.py
deleted file mode 100644
index 68af8d397..000000000
--- a/egs/zipvoice/local/prepare_token_file_emilia.py
+++ /dev/null
@@ -1,90 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
-# Wei Kang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file generates the file that maps tokens to IDs.
-"""
-
-import argparse
-import logging
-from pathlib import Path
-from typing import List
-
-from piper_phonemize import get_espeak_map
-from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
-
-
-def get_args():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--tokens",
- type=Path,
- default=Path("data/tokens_emilia.txt"),
- help="Path to the dict that maps the text tokens to IDs",
- )
-
- parser.add_argument(
- "--pinyin",
- type=Path,
- default=Path("local/pinyin.txt"),
- help="Path to the all unique pinyin",
- )
-
- return parser.parse_args()
-
-
-def get_pinyin_tokens(pinyin: Path) -> List[str]:
- phones = set()
- with open(pinyin, "r") as f:
- for line in f:
- x = line.strip()
- initial = to_initials(x, strict=False)
- # don't want to share tokens with espeak tokens, so use tone3 style
- finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
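-            # For example (illustrative): "zhong1" gives initial "zh" (stored
-            # as "zh0") and final "ong1"; a toneless syllable such as "ma"
-            # gives "a5" because neutral_tone_with_five=True.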
- if initial != "":
- # don't want to share tokens with espeak tokens, so add a '0' after each initial
- phones.add(initial + "0")
- if finals != "":
- phones.add(finals)
- return sorted(phones)
-
-
-def get_token2id(args):
- """Get a dict that maps token to IDs, and save it to the given filename."""
- all_tokens = get_espeak_map() # token: [token_id]
- all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
- # sort by token_id
- all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
-
- all_pinyin = get_pinyin_tokens(args.pinyin)
- with open(args.tokens, "w", encoding="utf-8") as f:
- for token, token_id in all_tokens:
- f.write(f"{token} {token_id}\n")
- num_espeak_tokens = len(all_tokens)
- for i, pinyin in enumerate(all_pinyin):
- f.write(f"{pinyin} {num_espeak_tokens + i}\n")
-
-
-if __name__ == "__main__":
- formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- logging.basicConfig(format=formatter, level=logging.INFO)
-
- args = get_args()
- get_token2id(args)
diff --git a/egs/zipvoice/local/prepare_token_file_libritts.py b/egs/zipvoice/local/prepare_token_file_libritts.py
deleted file mode 100644
index 374b02613..000000000
--- a/egs/zipvoice/local/prepare_token_file_libritts.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import re
-from collections import Counter
-
-from lhotse import load_manifest_lazy
-
-
-def prepare_tokens(manifest_file, token_file):
- counter = Counter()
- manifest = load_manifest_lazy(manifest_file)
- for cut in manifest:
- line = re.sub(r"\s+", " ", cut.supervisions[0].text)
- counter.update(line)
-
- unique_chars = set(counter.keys())
-
- if "_" in unique_chars:
- unique_chars.remove("_")
-
- sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True)
-
- result = ["_"] + sorted_chars
-
- with open(token_file, "w", encoding="utf-8") as file:
- for index, char in enumerate(result):
- file.write(f"{char} {index}\n")
-
-
-if __name__ == "__main__":
- manifest_file = "data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz"
- output_token_file = "data/tokens_libritts.txt"
- prepare_tokens(manifest_file, output_token_file)
diff --git a/egs/zipvoice/local/preprocess_emilia.py b/egs/zipvoice/local/preprocess_emilia.py
deleted file mode 100644
index 96a7ff228..000000000
--- a/egs/zipvoice/local/preprocess_emilia.py
+++ /dev/null
@@ -1,155 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
-# Zengrui Jin,
-# Wei Kang)
-#                    2024 Tsinghua University (authors: Zengrui Jin)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file reads the texts in the given manifest and saves the cleaned cuts.
-"""
-
-import argparse
-import glob
-import logging
-import os
-from concurrent.futures import ProcessPoolExecutor as Pool
-from pathlib import Path
-from typing import List
-
-from lhotse import CutSet, load_manifest_lazy
-from tokenizer import (
- is_alphabet,
- is_chinese,
- is_hangul,
- is_japanese,
- tokenize_by_CJK_char,
-)
-
-
-def get_args():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--subset",
- type=str,
- help="Subset of emilia, (ZH, EN, etc.)",
- )
-
- parser.add_argument(
- "--jobs",
- type=int,
- default=20,
- help="Number of jobs to processing.",
- )
-
- parser.add_argument(
- "--source-dir",
- type=str,
- default="data/manifests/splits_raw",
- help="The source directory of manifest files.",
- )
-
- parser.add_argument(
- "--dest-dir",
- type=str,
- default="data/manifests/splits",
- help="The destination directory of manifest files.",
- )
-
- return parser.parse_args()
-
-
-def preprocess_emilia(file_name: str, input_dir: Path, output_dir: Path):
- logging.info(f"Processing {file_name}")
- if (output_dir / file_name).is_file():
- logging.info(f"{file_name} exists, skipping.")
- return
-
- def _filter_cut(cut):
- text = cut.supervisions[0].text
- duration = cut.supervisions[0].duration
- chinese = []
- english = []
-
-        # keep only Chinese characters, alphabetic characters, and spaces
- clean_chars = []
- for x in text:
- if is_hangul(x):
- logging.warning(f"Delete cut with text containing Korean : {text}")
- return False
- if is_japanese(x):
- logging.warning(f"Delete cut with text containing Japanese : {text}")
- return False
- if is_chinese(x):
- chinese.append(x)
- clean_chars.append(x)
- if is_alphabet(x):
- english.append(x)
- clean_chars.append(x)
- if x == " ":
- clean_chars.append(x)
- if len(english) + len(chinese) == 0:
- logging.warning(f"Delete cut with text has no valid chars : {text}")
- return False
-
- words = tokenize_by_CJK_char("".join(clean_chars))
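-        # Drop the cut if any token repeats 10 times in a row, which usually
-        # indicates a corrupted transcription.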
- for i in range(len(words) - 10):
- if words[i : i + 10].count(words[i]) == 10:
- logging.warning(f"Delete cut with text with too much repeats : {text}")
- return False
- # word speed, 20 - 600 / minute
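-        # i.e. keep cuts whose speaking rate is roughly 20-600 words per minute:
-        # duration must lie within [len(words)/600*60, len(words)/20*60] seconds.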
- if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60:
- logging.warning(
- f"Delete cut with audio text mismatch, duration : {duration}s, "
- f"words : {len(words)}, text : {text}"
- )
- return False
- return True
-
- try:
- cut_set = load_manifest_lazy(input_dir / file_name)
- cut_set = cut_set.filter(_filter_cut)
- cut_set.to_file(output_dir / file_name)
-    except Exception as e:
-        logging.error(f"Manifest {file_name} failed with error: {e}")
-        # Remove a possibly partial output file so that a rerun regenerates it.
-        if (output_dir / file_name).is_file():
-            os.remove(str(output_dir / file_name))
-
-
-if __name__ == "__main__":
- formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
- logging.basicConfig(format=formatter, level=logging.INFO)
-
- args = get_args()
-
- input_dir = Path(args.source_dir)
- output_dir = Path(args.dest_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
-
- cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")
-
- with Pool(max_workers=args.jobs) as pool:
- futures = [
- pool.submit(
- preprocess_emilia, filename.split("/")[-1], input_dir, output_dir
- )
- for filename in cut_files
- ]
- for f in futures:
-            # Wait for completion and propagate any exception from the worker.
-            f.result()
- logging.info("Processing done.")
diff --git a/egs/zipvoice/local/tokenizer.py b/egs/zipvoice/local/tokenizer.py
deleted file mode 120000
index 024e340cc..000000000
--- a/egs/zipvoice/local/tokenizer.py
+++ /dev/null
@@ -1 +0,0 @@
-../zipvoice/tokenizer.py
\ No newline at end of file
diff --git a/egs/zipvoice/local/validate_manifest.py b/egs/zipvoice/local/validate_manifest.py
deleted file mode 100644
index 68159ae03..000000000
--- a/egs/zipvoice/local/validate_manifest.py
+++ /dev/null
@@ -1,70 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
-# Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This script checks the following assumptions of the generated manifest:
-
-- Single supervision per cut
-
-We will add more checks later if needed.
-
-Usage example:
-
- python3 ./local/validate_manifest.py \
- ./data/spectrogram/ljspeech_cuts_all.jsonl.gz
-
-"""
-
-import argparse
-import logging
-from pathlib import Path
-
-from lhotse import CutSet, load_manifest_lazy
-from lhotse.dataset.speech_synthesis import validate_for_tts
-
-
-def get_args():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "manifest",
- type=Path,
- help="Path to the manifest file",
- )
-
- return parser.parse_args()
-
-
-def main():
- args = get_args()
-
- manifest = args.manifest
- logging.info(f"Validating {manifest}")
-
- assert manifest.is_file(), f"{manifest} does not exist"
- cut_set = load_manifest_lazy(manifest)
- assert isinstance(cut_set, CutSet), type(cut_set)
-
- validate_for_tts(cut_set)
-
-
-if __name__ == "__main__":
- formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
- logging.basicConfig(format=formatter, level=logging.INFO)
-
- main()
diff --git a/egs/zipvoice/requirements.txt b/egs/zipvoice/requirements.txt
deleted file mode 100644
index cbbe860a5..000000000
--- a/egs/zipvoice/requirements.txt
+++ /dev/null
@@ -1,17 +0,0 @@
---find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
-
-torch
-torchaudio
-huggingface_hub
-lhotse
-safetensors
-vocos
-
-# Normalization
-cn2an
-inflect
-
-# Tokenization
-jieba
-piper_phonemize
-pypinyin
diff --git a/egs/zipvoice/scripts/evaluate.sh b/egs/zipvoice/scripts/evaluate.sh
deleted file mode 100644
index fbc0eed9a..000000000
--- a/egs/zipvoice/scripts/evaluate.sh
+++ /dev/null
@@ -1,102 +0,0 @@
-export CUDA_VISIBLE_DEVICES="0"
-export PYTHONWARNINGS=ignore
-export PYTHONPATH=../../:$PYTHONPATH
-
-# Uncomment this if you have trouble connecting to HuggingFace
-# export HF_ENDPOINT=https://hf-mirror.com
-
-start_stage=1
-end_stage=3
-
-# Models used for SIM-o evaluation.
-# SV model wavlm_large_finetune.pth is downloaded from https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
-# SSL model wavlm_large.pt is downloaded from https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
-sv_model_path=model/UniSpeech/wavlm_large_finetune.pth
-wavlm_model_path=model/s3prl/wavlm_large.pt
-
-# Models used for UTMOS evaluation.
-# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
-# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
-utmos_model_path=model/huggingface/utmos/utmos.pt
-wav2vec_model_path=model/huggingface/utmos/wav2vec_small.pt
-
-
-if [ $start_stage -le 1 ] && [ $end_stage -ge 1 ]; then
-
- echo "=====Evaluate for Seed-TTS test-en======="
- test_list=testset/test_seedtts_en.tsv
- wav_path=results/zipvoice_seedtts_en
-
- echo $wav_path
- echo "-----Computing SIM-o-----"
- python3 local/evaluate_sim.py \
- --sv-model-path ${sv_model_path} \
- --ssl-model-path ${wavlm_model_path} \
- --eval-path ${wav_path} \
- --test-list ${test_list}
-
- echo "-----Computing WER-----"
- python3 local/evaluate_wer_seedtts.py \
- --test-list ${test_list} \
- --wav-path ${wav_path} \
- --lang "en"
-
- echo "-----Computing UTSMOS-----"
- python3 local/evaluate_utmos.py \
- --wav-path ${wav_path} \
- --utmos-model-path ${utmos_model_path} \
- --ssl-model-path ${wav2vec_model_path}
-
-fi
-
-if [ $start_stage -le 2 ] && [ $end_stage -ge 2 ]; then
- echo "=====Evaluate for Seed-TTS test-zh======="
- test_list=testset/test_seedtts_zh.tsv
- wav_path=results/zipvoice_seedtts_zh
-
- echo $wav_path
- echo "-----Computing SIM-o-----"
- python3 local/evaluate_sim.py \
- --sv-model-path ${sv_model_path} \
- --ssl-model-path ${wavlm_model_path} \
- --eval-path ${wav_path} \
- --test-list ${test_list}
-
- echo "-----Computing WER-----"
- python3 local/evaluate_wer_seedtts.py \
- --test-list ${test_list} \
- --wav-path ${wav_path} \
- --lang "zh"
-
- echo "-----Computing UTSMOS-----"
- python3 local/evaluate_utmos.py \
- --wav-path ${wav_path} \
- --utmos-model-path ${utmos_model_path} \
- --ssl-model-path ${wav2vec_model_path}
-fi
-
-if [ $start_stage -le 3 ] && [ $end_stage -ge 3 ]; then
- echo "=====Evaluate for Librispeech test-clean======="
- test_list=testset/test_librispeech_pc_test_clean.tsv
- wav_path=results/zipvoice_librispeech_test_clean
-
- echo $wav_path
- echo "-----Computing SIM-o-----"
- python3 local/evaluate_sim.py \
- --sv-model-path ${sv_model_path} \
- --ssl-model-path ${wavlm_model_path} \
- --eval-path ${wav_path} \
- --test-list ${test_list}
-
- echo "-----Computing WER-----"
- python3 local/evaluate_wer_hubert.py \
- --test-list ${test_list} \
-    --wav-path ${wav_path}
-
- echo "-----Computing UTSMOS-----"
- python3 local/evaluate_utmos.py \
- --wav-path ${wav_path} \
- --utmos-model-path ${utmos_model_path} \
- --ssl-model-path ${wav2vec_model_path}
-
-fi
\ No newline at end of file
diff --git a/egs/zipvoice/scripts/prepare_emilia.sh b/egs/zipvoice/scripts/prepare_emilia.sh
deleted file mode 100755
index bf19ed1a5..000000000
--- a/egs/zipvoice/scripts/prepare_emilia.sh
+++ /dev/null
@@ -1,126 +0,0 @@
-#!/usr/bin/env bash
-
-# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
-export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-
-set -eou pipefail
-
-stage=0
-stop_stage=5
-sampling_rate=24000
-nj=32
-
-dl_dir=$PWD/download
-
-# All files generated by this script are saved in "data".
-# You can safely remove "data" and rerun this script to regenerate it.
-mkdir -p data
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-log "dl_dir: $dl_dir"
-
-if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
- log "Stage 0: Download data"
-
- # Your download directory should look like this:
- #
- # download/Amphion___Emilia
- # ├── metafile.yaml
- # ├── raw
- # │ ├── DE
- # │ ├── EN
- # │ ├── FR
- # │ ├── JA
- # │ ├── KO
- # │ ├── openemilia_45batches.tar.gz
- # │ ├── openemilia_all.tar.gz
- # │ └── ZH
- # └── README.md
-
- if [ ! -d $dl_dir/Amphion___Emilia/raw ]; then
- log "Please refer https://openxlab.org.cn/datasets/Amphion/Emilia to download the dataset."
- exit(-1)
- fi
-
-fi
-
-if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
- log "Stage 1: Prepare emilia manifests (EN and ZH only)"
- # We assume that you have downloaded the Emilia corpus
- # to $dl_dir/Amphion___Emilia
- # see stage 0 for the directory structure
- mkdir -p data/manifests
- if [ ! -e data/manifests/.emilia.done ]; then
- lhotse prepare emilia --lang en --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
- lhotse prepare emilia --lang zh --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
- touch data/manifests/.emilia.done
- fi
-fi
-
-if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
- log "Stage 2: Preprocess Emilia dataset, mainly for cleaning"
- mkdir -p data/manifests/splits_raw
-  if [ ! -e data/manifests/splits_raw/.emilia.split.done ]; then
- lhotse split-lazy data/manifests/emilia_cuts_EN.jsonl.gz data/manifests/splits_raw 10000
- lhotse split-lazy data/manifests/emilia_cuts_ZH.jsonl.gz data/manifests/splits_raw 10000
- touch data/manifests/splits_raw/.emilia.split.done
- fi
-
- mkdir -p data/manifests/splits
-
- if [ ! -e data/manifests/splits/.emilia.preprocess.done ]; then
- python local/preprocess_emilia.py --subset EN
- python local/preprocess_emilia.py --subset ZH
- touch data/manifests/splits/.emilia.preprocess.done
- fi
-
-fi
-
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
- log "Stage 3: Extract Fbank for Emilia"
- mkdir -p data/fbank/emilia_splits
- if [ ! -e data/fbank/emilia_splits/.emilia.fbank.done ]; then
- # You can speed up the extraction by distributing splits to multiple machines.
- for subset in EN ZH; do
- python local/compute_fbank.py \
- --source-dir data/manifests/splits \
- --dest-dir data/fbank/emilia_splits \
- --dataset emilia \
- --subset ${subset} \
- --splits-cuts 1 \
- --split-begin 0 \
- --split-end 2000 \
- --num-jobs ${nj}
- done
- touch data/fbank/emilia_splits/.emilia.fbank.done
- fi
-
- if [ ! -e data/fbank/emilia_cuts_EN.jsonl.gz ]; then
- log "Combining EN fbank cuts and spliting EN dev set"
- gunzip -c data/fbank/emilia_splits/emilia_cuts_EN.*.jsonl.gz > data/fbank/emilia_cuts_EN.jsonl
- head -n 1500 data/fbank/emilia_cuts_EN.jsonl | gzip -c > data/fbank/emilia_cuts_EN_dev.jsonl.gz
- sed -i '1,1500d' data/fbank/emilia_cuts_EN.jsonl
- gzip data/fbank/emilia_cuts_EN.jsonl
- fi
-
- if [ ! -e data/fbank/emilia_cuts_ZH.jsonl.gz ]; then
- log "Combining ZH fbank cuts and spliting ZH dev set"
- gunzip -c data/fbank/emilia_splits/emilia_cuts_ZH.*.jsonl.gz > data/fbank/emilia_cuts_ZH.jsonl
- head -n 1500 data/fbank/emilia_cuts_ZH.jsonl | gzip -c > data/fbank/emilia_cuts_ZH_dev.jsonl.gz
- sed -i '1,1500d' data/fbank/emilia_cuts_ZH.jsonl
- gzip data/fbank/emilia_cuts_ZH.jsonl
- fi
-
-fi
-
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
- log "Stage 4: Generate token file"
- if [ ! -e data/tokens_emilia.txt ]; then
- ./local/prepare_token_file_emilia.py --tokens data/tokens_emilia.txt
- fi
-fi
diff --git a/egs/zipvoice/scripts/prepare_libritts.sh b/egs/zipvoice/scripts/prepare_libritts.sh
deleted file mode 100755
index 6d643145e..000000000
--- a/egs/zipvoice/scripts/prepare_libritts.sh
+++ /dev/null
@@ -1,97 +0,0 @@
-#!/usr/bin/env bash
-
-# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
-export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-
-set -eou pipefail
-
-stage=0
-stop_stage=5
-sampling_rate=24000
-nj=20
-
-dl_dir=$PWD/download
-
-# All files generated by this script are saved in "data".
-# You can safely remove "data" and rerun this script to regenerate it.
-mkdir -p data
-
-log() {
- # This function is from espnet
- local fname=${BASH_SOURCE[1]##*/}
- echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-log "dl_dir: $dl_dir"
-
-if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
- log "Stage 0: Download data"
-
- # If you have pre-downloaded it to /path/to/LibriTTS,
- # you can create a symlink
- #
- # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
- #
- if [ ! -d $dl_dir/LibriTTS ]; then
- lhotse download libritts $dl_dir
- fi
-
-fi
-
-if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
- log "Stage 1: Prepare LibriTTS manifest"
- # We assume that you have downloaded the LibriTTS corpus
- # to $dl_dir/LibriTTS
-  mkdir -p data/manifests
-  if [ ! -e data/manifests/.libritts.done ]; then
- lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests
- touch data/manifests/.libritts.done
- fi
-fi
-
-if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
- log "Stage 2: Compute Fbank for LibriTTS"
- mkdir -p data/fbank
-
- if [ ! -e data/fbank/.libritts.done ]; then
- for subset in train-clean-100 train-clean-360 train-other-500 dev-clean test-clean; do
- python local/compute_fbank.py \
- --source-dir data/manifests \
- --dest-dir data/fbank \
- --dataset libritts \
- --subset ${subset} \
- --sampling-rate $sampling_rate \
- --num-jobs ${nj}
- done
- touch data/fbank/.libritts.done
- fi
-
- # Here we shuffle and combine the train-clean-100, train-clean-360 and
- # train-other-500 together to form the training set.
- if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
- cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
- <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
- <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
- shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
- fi
-
- if [ ! -f data/fbank/libritts_cuts_train-clean-460.jsonl.gz ]; then
- cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
- <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) | \
- shuf | gzip -c > data/fbank/libritts_cuts_train-clean-460.jsonl.gz
- fi
-
- if [ ! -e data/fbank/.libritts-validated.done ]; then
- log "Validating data/fbank for LibriTTS"
- ./local/validate_manifest.py \
- data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
- touch data/fbank/.libritts-validated.done
- fi
-fi
-
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
- log "Stage 3: Generate token file"
- if [ ! -e data/tokens_libritts.txt ]; then
- ./local/prepare_token_file_libritts.py --tokens data/tokens_libritts.txt
- fi
-fi
diff --git a/egs/zipvoice/zipvoice/checkpoint.py b/egs/zipvoice/zipvoice/checkpoint.py
deleted file mode 100644
index e3acd57dd..000000000
--- a/egs/zipvoice/zipvoice/checkpoint.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
-# Zengwei Yao)
-#
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
-
-import torch
-import torch.nn as nn
-from lhotse.dataset.sampling.base import CutSampler
-from torch.cuda.amp import GradScaler
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import Optimizer
-
-# use duck typing for LRScheduler since we have different possibilities, see
-# our class LRScheduler.
-LRSchedulerType = object
-
-
-def save_checkpoint(
- filename: Path,
- model: Union[nn.Module, DDP],
- model_avg: Optional[nn.Module] = None,
- model_ema: Optional[nn.Module] = None,
- params: Optional[Dict[str, Any]] = None,
- optimizer: Optional[Optimizer] = None,
- scheduler: Optional[LRSchedulerType] = None,
- scaler: Optional[GradScaler] = None,
- sampler: Optional[CutSampler] = None,
- rank: int = 0,
-) -> None:
- """Save training information to a file.
-
- Args:
- filename:
- The checkpoint filename.
- model:
- The model to be saved. We only save its `state_dict()`.
- model_avg:
- The stored model averaged from the start of training.
- model_ema:
- The EMA version of model.
- params:
- User defined parameters, e.g., epoch, loss.
- optimizer:
- The optimizer to be saved. We only save its `state_dict()`.
- scheduler:
- The scheduler to be saved. We only save its `state_dict()`.
- scalar:
- The GradScaler to be saved. We only save its `state_dict()`.
- sampler:
- The sampler used in the labeled training dataset. We only
- save its `state_dict()`.
- rank:
- Used in DDP. We save checkpoint only for the node whose
- rank is 0.
- Returns:
- Return None.
- """
- if rank != 0:
- return
-
- logging.info(f"Saving checkpoint to {filename}")
-
- if isinstance(model, DDP):
- model = model.module
-
- checkpoint = {
- "model": model.state_dict(),
- "optimizer": optimizer.state_dict() if optimizer is not None else None,
- "scheduler": scheduler.state_dict() if scheduler is not None else None,
- "grad_scaler": scaler.state_dict() if scaler is not None else None,
- "sampler": sampler.state_dict() if sampler is not None else None,
- }
-
- if model_avg is not None:
- checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
- if model_ema is not None:
- checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict()
-
- if params:
- for k, v in params.items():
- assert k not in checkpoint
- checkpoint[k] = v
-
- torch.save(checkpoint, filename)
-
-
-def load_checkpoint(
- filename: Path,
- model: Optional[nn.Module] = None,
- model_avg: Optional[nn.Module] = None,
- model_ema: Optional[nn.Module] = None,
- strict: bool = False,
-) -> Dict[str, Any]:
- logging.info(f"Loading checkpoint from {filename}")
- checkpoint = torch.load(filename, map_location="cpu")
-
- if model is not None:
-
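-        # Checkpoints saved from a DDP-wrapped model have parameter names
-        # prefixed with "module."; strip that prefix so the state dict matches
-        # the bare (unwrapped) model.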
- if next(iter(checkpoint["model"])).startswith("module."):
- logging.info("Loading checkpoint saved by DDP")
-
- dst_state_dict = model.state_dict()
- src_state_dict = checkpoint["model"]
- for key in dst_state_dict.keys():
- src_key = "{}.{}".format("module", key)
- dst_state_dict[key] = src_state_dict.pop(src_key)
- assert len(src_state_dict) == 0
- model.load_state_dict(dst_state_dict, strict=strict)
- else:
- logging.info("Loading checkpoint")
- model.load_state_dict(checkpoint["model"], strict=strict)
-
- checkpoint.pop("model")
-
- if model_avg is not None and "model_avg" in checkpoint:
- logging.info("Loading averaged model")
- model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
- checkpoint.pop("model_avg")
-
- if model_ema is not None and "model_ema" in checkpoint:
- logging.info("Loading ema model")
- model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
- checkpoint.pop("model_ema")
-
- return checkpoint
diff --git a/egs/zipvoice/zipvoice/feature.py b/egs/zipvoice/zipvoice/feature.py
deleted file mode 100644
index e7d484d10..000000000
--- a/egs/zipvoice/zipvoice/feature.py
+++ /dev/null
@@ -1,135 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torchaudio
-from lhotse.features.base import FeatureExtractor, register_extractor
-from lhotse.utils import Seconds, compute_num_frames
-
-
-class MelSpectrogramFeatures(nn.Module):
- def __init__(
- self,
- sampling_rate=24000,
- n_mels=100,
- n_fft=1024,
- hop_length=256,
- ):
- super().__init__()
-
- self.mel_spec = torchaudio.transforms.MelSpectrogram(
- sample_rate=sampling_rate,
- n_fft=n_fft,
- hop_length=hop_length,
- n_mels=n_mels,
- center=True,
- power=1,
- )
-
- def forward(self, inp):
- assert len(inp.shape) == 2
-
- mel = self.mel_spec(inp)
- logmel = mel.clamp(min=1e-7).log()
- return logmel
-
-
-@dataclass
-class TorchAudioFbankConfig:
- sampling_rate: int
- n_mels: int
- n_fft: int
- hop_length: int
-
-
-@register_extractor
-class TorchAudioFbank(FeatureExtractor):
-
- name = "TorchAudioFbank"
- config_type = TorchAudioFbankConfig
-
- def __init__(self, config):
- super().__init__(config=config)
-
- def _feature_fn(self, sample):
- fbank = MelSpectrogramFeatures(
- sampling_rate=self.config.sampling_rate,
- n_mels=self.config.n_mels,
- n_fft=self.config.n_fft,
- hop_length=self.config.hop_length,
- )
-
- return fbank(sample)
-
- @property
- def device(self) -> Union[str, torch.device]:
- return self.config.device
-
- def feature_dim(self, sampling_rate: int) -> int:
- return self.config.n_mels
-
- def extract(
- self,
- samples: Union[np.ndarray, torch.Tensor],
- sampling_rate: int,
- ) -> Union[np.ndarray, torch.Tensor]:
- # Check for sampling rate compatibility.
- expected_sr = self.config.sampling_rate
- assert sampling_rate == expected_sr, (
- f"Mismatched sampling rate: extractor expects {expected_sr}, "
- f"got {sampling_rate}"
- )
- is_numpy = False
- if not isinstance(samples, torch.Tensor):
- samples = torch.from_numpy(samples)
- is_numpy = True
-
- if len(samples.shape) == 1:
- samples = samples.unsqueeze(0)
- assert samples.ndim == 2, samples.shape
- assert samples.shape[0] == 1, samples.shape
-
- mel = self._feature_fn(samples).squeeze().t()
-
- assert mel.ndim == 2, mel.shape
- assert mel.shape[1] == self.config.n_mels, mel.shape
-
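-        # Lhotse expects exactly `num_frames` frames for this sample; trim any
-        # extra frames or replicate-pad the tail so the shapes agree.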
- num_frames = compute_num_frames(
- samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
- )
-
- if mel.shape[0] > num_frames:
- mel = mel[:num_frames]
- elif mel.shape[0] < num_frames:
- mel = mel.unsqueeze(0)
- mel = torch.nn.functional.pad(
- mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
- ).squeeze(0)
-
- if is_numpy:
- return mel.cpu().numpy()
- else:
- return mel
-
- @property
- def frame_shift(self) -> Seconds:
- return self.config.hop_length / self.config.sampling_rate
diff --git a/egs/zipvoice/zipvoice/generate_averaged_model.py b/egs/zipvoice/zipvoice/generate_averaged_model.py
deleted file mode 100644
index e1b7ca7c6..000000000
--- a/egs/zipvoice/zipvoice/generate_averaged_model.py
+++ /dev/null
@@ -1,209 +0,0 @@
-#!/usr/bin/env python3
-#
-# Copyright 2021-2022 Xiaomi Corporation
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Usage:
-This script loads checkpoints and averages them.
-
-(1) Average ZipVoice models before distillation:
- python3 ./zipvoice/generate_averaged_model.py \
- --epoch 11 \
- --avg 4 \
- --distill 0 \
- --token-file data/tokens_emilia.txt \
- --exp-dir ./zipvoice/exp_zipvoice
-
-    It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
- You can later load it by `torch.load("epoch-11-avg-4.pt")`.
-
-(2) Average ZipVoice-Distill models (the first stage model):
-
- python3 ./zipvoice/generate_averaged_model.py \
- --iter 60000 \
- --avg 7 \
- --distill 1 \
- --token-file data/tokens_emilia.txt \
- --exp-dir ./zipvoice/exp_zipvoice_distill_1stage
-"""
-
-import argparse
-from pathlib import Path
-
-import torch
-from model import get_distill_model, get_model
-from tokenizer import TokenizerEmilia, TokenizerLibriTTS
-from train_flow import add_model_arguments, get_params
-
-from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
-from icefall.utils import str2bool
-
-
-def get_parser():
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter
- )
-
- parser.add_argument(
- "--epoch",
- type=int,
- default=11,
- help="""It specifies the checkpoint to use for decoding.
- Note: Epoch counts from 1.
- You can specify --avg to use more checkpoints for model averaging.""",
- )
-
- parser.add_argument(
- "--iter",
- type=int,
- default=0,
- help="""If positive, --epoch is ignored and it
- will use the checkpoint exp_dir/checkpoint-iter.pt.
- You can specify --avg to use more checkpoints for model averaging.
- """,
- )
-
- parser.add_argument(
- "--avg",
- type=int,
- default=4,
- help="Number of checkpoints to average. Automatically select "
- "consecutive checkpoints before the checkpoint specified by "
- "'--epoch' or --iter",
- )
-
- parser.add_argument(
- "--exp-dir",
- type=str,
- default="zipvoice/exp_zipvoice",
- help="The experiment dir",
- )
-
- parser.add_argument(
- "--distill",
- type=str2bool,
- default=False,
- help="Whether to use distill model. ",
- )
-
- parser.add_argument(
- "--dataset",
- type=str,
- default="emilia",
- choices=["emilia", "libritts"],
- help="The used training dataset for the model to inference",
- )
-
- add_model_arguments(parser)
-
- return parser
-
-
-@torch.no_grad()
-def main():
- parser = get_parser()
- args = parser.parse_args()
- args.exp_dir = Path(args.exp_dir)
-
- params = get_params()
- params.update(vars(args))
-
- if params.dataset == "emilia":
- tokenizer = TokenizerEmilia(
- token_file=params.token_file, token_type=params.token_type
- )
- elif params.dataset == "libritts":
- tokenizer = TokenizerLibriTTS(
- token_file=params.token_file, token_type=params.token_type
- )
-
- params.vocab_size = tokenizer.vocab_size
- params.pad_id = tokenizer.pad_id
-
- params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-
- print("Script started")
-
- params.device = torch.device("cpu")
- print(f"Device: {params.device}")
-
- print("About to create model")
- if params.distill:
- model = get_distill_model(params)
- else:
- model = get_model(params)
-
- if params.iter > 0:
- filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
- : params.avg + 1
- ]
- if len(filenames) == 0:
- raise ValueError(
- f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
- )
- elif len(filenames) < params.avg + 1:
- raise ValueError(
- f"Not enough checkpoints ({len(filenames)}) found for"
- f" --iter {params.iter}, --avg {params.avg}"
- )
- filename_start = filenames[-1]
- filename_end = filenames[0]
- print(
- "Calculating the averaged model over iteration checkpoints"
- f" from {filename_start} (excluded) to {filename_end}"
- )
- model.to(params.device)
- model.load_state_dict(
- average_checkpoints_with_averaged_model(
- filename_start=filename_start,
- filename_end=filename_end,
- device=params.device,
- ),
- strict=True,
- )
- else:
- assert params.avg > 0, params.avg
- start = params.epoch - params.avg
- assert start >= 1, start
- filename_start = f"{params.exp_dir}/epoch-{start}.pt"
- filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
- print(
- f"Calculating the averaged model over epoch range from "
- f"{start} (excluded) to {params.epoch}"
- )
- model.to(params.device)
- model.load_state_dict(
- average_checkpoints_with_averaged_model(
- filename_start=filename_start,
- filename_end=filename_end,
- device=params.device,
- ),
- strict=True,
- )
- if params.iter > 0:
- filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
- else:
- filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
- torch.save({"model": model.state_dict()}, filename)
-
- num_param = sum([p.numel() for p in model.parameters()])
- print(f"Number of model parameters: {num_param}")
-
- print("Done!")
-
-
-if __name__ == "__main__":
- main()
diff --git a/egs/zipvoice/zipvoice/infer.py b/egs/zipvoice/zipvoice/infer.py
deleted file mode 100644
index 2819d3c85..000000000
--- a/egs/zipvoice/zipvoice/infer.py
+++ /dev/null
@@ -1,586 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
-# Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script loads checkpoints to generate waveforms.
-It is intended for models that you trained yourself.
-If you want to use the pre-trained checkpoints provided by us, please refer to zipvoice_infer.py.
-
-(1) Usage with a specific checkpoint file:
-
-    (a) ZipVoice model before distillation:
- python3 zipvoice/infer.py \
- --checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
- --distill 0 \
- --token-file "data/tokens_emilia.txt" \
- --test-list test.tsv \
- --res-dir results/test \
- --num-step 16 \
- --guidance-scale 1
-
- (b) ZipVoice-Distill:
- python3 zipvoice/infer.py \
- --checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
- --distill 1 \
- --token-file "data/tokens_emilia.txt" \
- --test-list test.tsv \
- --res-dir results/test_distill \
- --num-step 8 \
- --guidance-scale 3
-
-(2) Usage with a directory of checkpoints (may require checkpoint averaging):
-
-    (a) ZipVoice model before distillation:
-    python3 zipvoice/infer.py \
- --exp-dir zipvoice/exp_zipvoice \
- --epoch 11 \
- --avg 4 \
- --distill 0 \
- --token-file "data/tokens_emilia.txt" \
- --test-list test.tsv \
- --res-dir results \
- --num-step 16 \
- --guidance-scale 1
-
- (b) ZipVoice-Distill:
-    python3 zipvoice/infer.py \
- --exp-dir zipvoice/exp_zipvoice_distill/ \
- --iter 2000 \
- --avg 0 \
- --distill 1 \
- --token-file "data/tokens_emilia.txt" \
- --test-list test.tsv \
- --res-dir results \
- --num-step 8 \
- --guidance-scale 3
-"""
-
-
-import argparse
-import datetime as dt
-import logging
-import os
-from pathlib import Path
-from typing import Optional
-
-import numpy as np
-import soundfile as sf
-import torch
-import torch.nn as nn
-import torchaudio
-from checkpoint import load_checkpoint
-from feature import TorchAudioFbank, TorchAudioFbankConfig
-from lhotse.utils import fix_random_seed
-from model import get_distill_model, get_model
-from tokenizer import TokenizerEmilia, TokenizerLibriTTS
-from train_flow import add_model_arguments, get_params
-from vocos import Vocos
-
-from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
-from icefall.utils import AttributeDict, setup_logger, str2bool
-
-
-def get_parser():
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter
- )
-
- parser.add_argument(
- "--checkpoint",
- type=str,
- default=None,
- help="The checkpoint for inference. "
- "If it is None, will use checkpoints under exp_dir",
- )
-
- parser.add_argument(
- "--exp-dir",
- type=str,
- default="zipvoice/exp",
- help="The experiment dir",
- )
-
- parser.add_argument(
- "--epoch",
- type=int,
- default=0,
- help="""It specifies the checkpoint to use for decoding.
- Note: Epoch counts from 1.
- You can specify --avg to use more checkpoints for model averaging.""",
- )
-
- parser.add_argument(
- "--iter",
- type=int,
- default=0,
- help="""If positive, --epoch is ignored and it
- will use the checkpoint exp_dir/checkpoint-iter.pt.
- You can specify --avg to use more checkpoints for model averaging.
- """,
- )
-
- parser.add_argument(
- "--avg",
- type=int,
- default=4,
- help="Number of checkpoints to average. Automatically select "
- "consecutive checkpoints before the checkpoint specified by "
- "'--epoch' or '--iter', avg=0 means no avg",
- )
-
- parser.add_argument(
- "--vocoder-path",
- type=str,
- default=None,
- help="The local vocos vocoder path, downloaded from huggingface, "
- "will download the vocodoer from huggingface if it is None.",
- )
-
- parser.add_argument(
- "--distill",
- type=str2bool,
- default=False,
- help="Whether it is a distilled TTS model.",
- )
-
- parser.add_argument(
- "--test-list",
- type=str,
- default=None,
- help="The list of prompt speech, prompt_transcription, "
- "and text to synthesize in the format of "
- "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
- )
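-    # An illustrative test-list line (fields separated by tabs):
-    #   utt_0001\tTranscript of the prompt audio.\tprompts/utt_0001.wav\tText to synthesize.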
-
- parser.add_argument(
- "--res-dir",
- type=str,
- default="results",
- help="Path name of the generated wavs dir",
- )
-
- parser.add_argument(
- "--dataset",
- type=str,
- default="emilia",
- choices=["emilia", "libritts"],
- help="The used training dataset for the model to inference",
- )
-
- parser.add_argument(
- "--guidance-scale",
- type=float,
- default=1.0,
- help="The scale of classifier-free guidance during inference.",
- )
-
- parser.add_argument(
- "--num-step",
- type=int,
- default=16,
- help="The number of sampling steps.",
- )
-
- parser.add_argument(
- "--feat-scale",
- type=float,
- default=0.1,
- help="The scale factor of fbank feature",
- )
-
- parser.add_argument(
- "--speed",
- type=float,
- default=1.0,
- help="Control speech speed, 1.0 means normal, >1.0 means speed up",
- )
-
- parser.add_argument(
- "--t-shift",
- type=float,
- default=0.5,
- help="Shift t to smaller ones if t_shift < 1.0",
- )
-
- parser.add_argument(
- "--target-rms",
- type=float,
- default=0.1,
- help="Target speech normalization rms value",
- )
-
- parser.add_argument(
- "--seed",
- type=int,
- default=666,
- help="Random seed",
- )
-
- add_model_arguments(parser)
-
- return parser
-
-
-def get_vocoder(vocos_local_path: Optional[str] = None):
-    if vocos_local_path:
-        # Load the vocoder from the user-provided local directory.
- vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
- state_dict = torch.load(
- f"{vocos_local_path}/pytorch_model.bin",
- weights_only=True,
- map_location="cpu",
- )
- vocoder.load_state_dict(state_dict)
- else:
- vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
- return vocoder
-
-
-def generate_sentence(
- save_path: str,
- prompt_text: str,
- prompt_wav: str,
- text: str,
- model: nn.Module,
- vocoder: nn.Module,
- tokenizer: TokenizerEmilia,
- feature_extractor: TorchAudioFbank,
- device: torch.device,
- num_step: int = 16,
- guidance_scale: float = 1.0,
- speed: float = 1.0,
- t_shift: float = 0.5,
- target_rms: float = 0.1,
- feat_scale: float = 0.1,
- sampling_rate: int = 24000,
-):
- """
- Generate waveform of a text based on a given prompt
- waveform and its transcription.
-
- Args:
- save_path (str): Path to save the generated wav.
- prompt_text (str): Transcription of the prompt wav.
- prompt_wav (str): Path to the prompt wav file.
- text (str): Text to be synthesized into a waveform.
- model (nn.Module): The model used for generation.
- vocoder (nn.Module): The vocoder used to convert features to waveforms.
- tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
- feature_extractor (TorchAudioFbank): The feature extractor used to
- extract acoustic features.
- device (torch.device): The device on which computations are performed.
- num_step (int, optional): Number of steps for decoding. Defaults to 16.
- guidance_scale (float, optional): Scale for classifier-free guidance.
- Defaults to 1.0.
- speed (float, optional): Speed control. Defaults to 1.0.
- t_shift (float, optional): Time shift. Defaults to 0.5.
- target_rms (float, optional): Target RMS for waveform normalization.
- Defaults to 0.1.
- feat_scale (float, optional): Scale for features.
- Defaults to 0.1.
- sampling_rate (int, optional): Sampling rate for the waveform.
- Defaults to 24000.
- Returns:
- metrics (dict): Dictionary containing time and real-time
- factor metrics for processing.
- """
- # Convert text to tokens
- tokens = tokenizer.texts_to_token_ids([text])
- prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
-
- # Load and preprocess prompt wav
- prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
- prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
- if prompt_rms < target_rms:
- prompt_wav = prompt_wav * target_rms / prompt_rms
-
- if prompt_sampling_rate != sampling_rate:
- resampler = torchaudio.transforms.Resample(
- orig_freq=prompt_sampling_rate, new_freq=sampling_rate
- )
- prompt_wav = resampler(prompt_wav)
-
- # Extract features from prompt wav
- prompt_features = feature_extractor.extract(
- prompt_wav, sampling_rate=sampling_rate
- ).to(device)
- prompt_features = prompt_features.unsqueeze(0) * feat_scale
- prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
-
- # Start timing
- start_t = dt.datetime.now()
-
- # Generate features
- (
- pred_features,
- pred_features_lens,
- pred_prompt_features,
- pred_prompt_features_lens,
- ) = model.sample(
- tokens=tokens,
- prompt_tokens=prompt_tokens,
- prompt_features=prompt_features,
- prompt_features_lens=prompt_features_lens,
- speed=speed,
- t_shift=t_shift,
- duration="predict",
- num_step=num_step,
- guidance_scale=guidance_scale,
- )
-
- # Postprocess predicted features
- pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
-
- # Start vocoder processing
- start_vocoder_t = dt.datetime.now()
- wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
-
- # Calculate processing times and real-time factors
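-    # (RTF = processing time / duration of generated audio; lower is better,
-    # and RTF < 1 means faster than real time.)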
- t = (dt.datetime.now() - start_t).total_seconds()
- t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
- t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
- wav_seconds = wav.shape[-1] / sampling_rate
- rtf = t / wav_seconds
- rtf_no_vocoder = t_no_vocoder / wav_seconds
- rtf_vocoder = t_vocoder / wav_seconds
- metrics = {
- "t": t,
- "t_no_vocoder": t_no_vocoder,
- "t_vocoder": t_vocoder,
- "wav_seconds": wav_seconds,
- "rtf": rtf,
- "rtf_no_vocoder": rtf_no_vocoder,
- "rtf_vocoder": rtf_vocoder,
- }
-
- # Adjust wav volume if necessary
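-    # (Undo the prompt RMS normalization applied above so that the output
-    # loudness matches the original prompt.)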
- if prompt_rms < target_rms:
- wav = wav * prompt_rms / target_rms
- wav = wav[0].cpu().numpy()
- sf.write(save_path, wav, sampling_rate)
-
- return metrics
-
-
-def generate(
- params: AttributeDict,
- model: nn.Module,
- vocoder: nn.Module,
- tokenizer: TokenizerEmilia,
-):
- total_t = []
- total_t_no_vocoder = []
- total_t_vocoder = []
- total_wav_seconds = []
-
- config = TorchAudioFbankConfig(
- sampling_rate=params.sampling_rate,
- n_mels=100,
- n_fft=1024,
- hop_length=256,
- )
- feature_extractor = TorchAudioFbank(config)
-
- with open(params.test_list, "r") as fr:
- lines = fr.readlines()
-
- for i, line in enumerate(lines):
- wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
- save_path = f"{params.wav_dir}/{wav_name}.wav"
- metrics = generate_sentence(
- save_path=save_path,
- prompt_text=prompt_text,
- prompt_wav=prompt_wav,
- text=text,
- model=model,
- vocoder=vocoder,
- tokenizer=tokenizer,
- feature_extractor=feature_extractor,
- device=params.device,
- num_step=params.num_step,
- guidance_scale=params.guidance_scale,
- speed=params.speed,
- t_shift=params.t_shift,
- target_rms=params.target_rms,
- feat_scale=params.feat_scale,
- sampling_rate=params.sampling_rate,
- )
- print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
- total_t.append(metrics["t"])
- total_t_no_vocoder.append(metrics["t_no_vocoder"])
- total_t_vocoder.append(metrics["t_vocoder"])
- total_wav_seconds.append(metrics["wav_seconds"])
-
- print(f"Average RTF: " f"{np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
- print(
- f"Average RTF w/o vocoder: "
- f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
- )
- print(
- f"Average RTF vocoder: "
- f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
- )
-
-
-@torch.inference_mode()
-def main():
- parser = get_parser()
- args = parser.parse_args()
- args.exp_dir = Path(args.exp_dir)
-
- params = get_params()
- params.update(vars(args))
-
- if params.iter > 0:
- params.suffix = (
- f"wavs-iter-{params.iter}-avg"
- f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
- )
- elif params.epoch > 0:
- params.suffix = (
- f"wavs-epoch-{params.epoch}-avg"
- f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
- )
- else:
- params.suffix = "wavs"
-
- setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
- logging.info("Decoding started")
-
- if torch.cuda.is_available():
- params.device = torch.device("cuda", 0)
- else:
- params.device = torch.device("cpu")
-
- logging.info(f"Device: {params.device}")
-
- if params.dataset == "emilia":
- tokenizer = TokenizerEmilia(
- token_file=params.token_file, token_type=params.token_type
- )
- elif params.dataset == "libritts":
- tokenizer = TokenizerLibriTTS(
- token_file=params.token_file, token_type=params.token_type
- )
-
- params.vocab_size = tokenizer.vocab_size
- params.pad_id = tokenizer.pad_id
-
- logging.info(params)
- fix_random_seed(params.seed)
-
- logging.info("About to create model")
- if params.distill:
- model = get_distill_model(params)
- else:
- model = get_model(params)
-
- if params.checkpoint:
- load_checkpoint(params.checkpoint, model, strict=True)
- else:
- if params.avg == 0:
- if params.iter > 0:
- load_checkpoint(
- f"{params.exp_dir}/checkpoint-{params.iter}.pt", model, strict=True
- )
- else:
- load_checkpoint(
- f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True
- )
- else:
- if params.iter > 0:
- filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
- : params.avg + 1
- ]
- if len(filenames) == 0:
- raise ValueError(
- f"No checkpoints found for"
- f" --iter {params.iter}, --avg {params.avg}"
- )
- elif len(filenames) < params.avg + 1:
- raise ValueError(
- f"Not enough checkpoints ({len(filenames)}) found for"
- f" --iter {params.iter}, --avg {params.avg}"
- )
- filename_start = filenames[-1]
- filename_end = filenames[0]
- logging.info(
- "Calculating the averaged model over iteration checkpoints"
- f" from {filename_start} (excluded) to {filename_end}"
- )
- model.to(params.device)
- model.load_state_dict(
- average_checkpoints_with_averaged_model(
- filename_start=filename_start,
- filename_end=filename_end,
- device=params.device,
- ),
- strict=True,
- )
- else:
- assert params.avg > 0, params.avg
- start = params.epoch - params.avg
- assert start >= 1, start
- filename_start = f"{params.exp_dir}/epoch-{start}.pt"
- filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
- logging.info(
- f"Calculating the averaged model over epoch range from "
- f"{start} (excluded) to {params.epoch}"
- )
- model.to(params.device)
- model.load_state_dict(
- average_checkpoints_with_averaged_model(
- filename_start=filename_start,
- filename_end=filename_end,
- device=params.device,
- ),
- strict=True,
- )
-
- model = model.to(params.device)
- model.eval()
-
- num_param = sum([p.numel() for p in model.parameters()])
- logging.info(f"Number of model parameters: {num_param}")
-
- vocoder = get_vocoder(params.vocoder_path)
- vocoder = vocoder.to(params.device)
- vocoder.eval()
- num_param = sum([p.numel() for p in vocoder.parameters()])
- logging.info(f"Number of vocoder parameters: {num_param}")
-
- params.wav_dir = f"{params.res_dir}/{params.suffix}"
- os.makedirs(params.wav_dir, exist_ok=True)
-
- assert (
- params.test_list is not None
- ), "Please provide --test-list for speech synthesize."
- generate(
- params=params,
- model=model,
- vocoder=vocoder,
- tokenizer=tokenizer,
- )
-
- logging.info("Done!")
-
-
-if __name__ == "__main__":
- torch.set_num_threads(1)
- torch.set_num_interop_threads(1)
- main()
diff --git a/egs/zipvoice/zipvoice/model.py b/egs/zipvoice/zipvoice/model.py
deleted file mode 100644
index 25c7973b2..000000000
--- a/egs/zipvoice/zipvoice/model.py
+++ /dev/null
@@ -1,578 +0,0 @@
-# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
-# Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import List, Optional
-
-import torch
-import torch.nn as nn
-from scaling import ScheduledFloat
-from solver import EulerSolver
-from torch.nn.parallel import DistributedDataParallel as DDP
-from utils import (
- AttributeDict,
- condition_time_mask,
- get_tokens_index,
- make_pad_mask,
- pad_labels,
- prepare_avg_tokens_durations,
- to_int_tuple,
-)
-from zipformer import TTSZipformer
-
-
-def get_model(params: AttributeDict) -> nn.Module:
- """Get the normal TTS model."""
-
- fm_decoder = get_fm_decoder_model(params)
- text_encoder = get_text_encoder_model(params)
-
- model = TtsModel(
- fm_decoder=fm_decoder,
- text_encoder=text_encoder,
- text_embed_dim=params.text_embed_dim,
- feat_dim=params.feat_dim,
- vocab_size=params.vocab_size,
- pad_id=params.pad_id,
- )
- return model
-
-
-def get_distill_model(params: AttributeDict) -> nn.Module:
- """Get the distillation TTS model."""
-
- fm_decoder = get_fm_decoder_model(params, distill=True)
- text_encoder = get_text_encoder_model(params)
-
- model = DistillTTSModelTrainWrapper(
- fm_decoder=fm_decoder,
- text_encoder=text_encoder,
- text_embed_dim=params.text_embed_dim,
- feat_dim=params.feat_dim,
- vocab_size=params.vocab_size,
- pad_id=params.pad_id,
- )
- return model
-
-
-def get_fm_decoder_model(params: AttributeDict, distill: bool = False) -> nn.Module:
- """Get the Zipformer-based FM decoder model."""
-
- encoder = TTSZipformer(
- in_dim=params.feat_dim * 3,
- out_dim=params.feat_dim,
- downsampling_factor=to_int_tuple(params.fm_decoder_downsampling_factor),
- num_encoder_layers=to_int_tuple(params.fm_decoder_num_layers),
- cnn_module_kernel=to_int_tuple(params.fm_decoder_cnn_module_kernel),
- encoder_dim=params.fm_decoder_dim,
- feedforward_dim=params.fm_decoder_feedforward_dim,
- num_heads=params.fm_decoder_num_heads,
- query_head_dim=params.query_head_dim,
- pos_head_dim=params.pos_head_dim,
- value_head_dim=params.value_head_dim,
- pos_dim=params.pos_dim,
- dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
- warmup_batches=4000.0,
- use_time_embed=True,
- time_embed_dim=192,
- use_guidance_scale_embed=distill,
- )
- return encoder
-
-
-def get_text_encoder_model(params: AttributeDict) -> nn.Module:
- """Get the Zipformer-based text encoder model."""
-
- encoder = TTSZipformer(
- in_dim=params.text_embed_dim,
- out_dim=params.feat_dim,
- downsampling_factor=to_int_tuple(params.text_encoder_downsampling_factor),
- num_encoder_layers=to_int_tuple(params.text_encoder_num_layers),
- cnn_module_kernel=to_int_tuple(params.text_encoder_cnn_module_kernel),
- encoder_dim=params.text_encoder_dim,
- feedforward_dim=params.text_encoder_feedforward_dim,
- num_heads=params.text_encoder_num_heads,
- query_head_dim=params.query_head_dim,
- pos_head_dim=params.pos_head_dim,
- value_head_dim=params.value_head_dim,
- pos_dim=params.pos_dim,
- dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
- warmup_batches=4000.0,
- use_time_embed=False,
- )
- return encoder
-
-
-class TtsModel(nn.Module):
- """The normal TTS model."""
-
- def __init__(
- self,
- fm_decoder: nn.Module,
- text_encoder: nn.Module,
- text_embed_dim: int,
- feat_dim: int,
- vocab_size: int,
- pad_id: int = 0,
- ):
- """
- Args:
- fm_decoder: the flow-matching decoder model; its inputs are the
- condition embeddings and noisy acoustic features, and its
- output is the predicted velocity.
- text_encoder: the text encoder model; its inputs are text
- embeddings, and its outputs are contextualized text embeddings.
- text_embed_dim: dimension of text embedding.
- feat_dim: dimension of acoustic features.
- vocab_size: vocabulary size.
- pad_id: padding id.
- """
- super().__init__()
-
- self.feat_dim = feat_dim
- self.text_embed_dim = text_embed_dim
- self.pad_id = pad_id
-
- self.fm_decoder = fm_decoder
-
- self.text_encoder = text_encoder
-
- self.embed = nn.Embedding(vocab_size, text_embed_dim)
-
- self.distill = False
-
- def forward_fm_decoder(
- self,
- t: torch.Tensor,
- xt: torch.Tensor,
- text_condition: torch.Tensor,
- speech_condition: torch.Tensor,
- padding_mask: Optional[torch.Tensor] = None,
- guidance_scale: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- """Compute velocity.
- Args:
- t: A tensor of shape (N, 1, 1) or a tensor of a float,
- in the range of (0, 1).
- xt: the noisy acoustic features at the current timestep, with the
- shape (batch, seq_len, feat_dim); the condition embeddings are
- concatenated to it inside this function.
- text_condition: the text condition embeddings, with the
- shape (batch, seq_len, emb_dim).
- speech_condition: the speech condition embeddings, with the
- shape (batch, seq_len, emb_dim).
- padding_mask: The mask for padding, True means masked
- position, with the shape (N, T).
- guidance_scale: The guidance scale in classifier-free guidance,
- which is a tensor of shape (N, 1, 1) or a tensor of a float.
-
- Returns:
- predicted velocity, with the shape (batch, seq_len, emb_dim).
- """
- assert t.dim() in (0, 3)
- # Handle t with the shape (N, 1, 1):
- # squeeze the last dimension if its size is 1.
- while t.dim() > 1 and t.size(-1) == 1:
- t = t.squeeze(-1)
- if guidance_scale is not None:
- while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1:
- guidance_scale = guidance_scale.squeeze(-1)
- # Handle t with a single value: expand to the batch size.
- if t.dim() == 0:
- t = t.repeat(xt.shape[0])
- if guidance_scale is not None and guidance_scale.dim() == 0:
- guidance_scale = guidance_scale.repeat(xt.shape[0])
-
- xt = torch.cat([xt, text_condition, speech_condition], dim=2)
- vt = self.fm_decoder(
- x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale
- )
- return vt
-
- def forward_text_embed(
- self,
- tokens: List[List[int]],
- ):
- """
- Get the text embeddings.
- Args:
- tokens: a list of list of token ids.
- Returns:
- embed: the text embeddings, shape (batch, seq_len, emb_dim).
- tokens_lens: the length of each token sequence, shape (batch,).
- """
- device = (
- self.device if isinstance(self, DDP) else next(self.parameters()).device
- )
- tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
- embed = self.embed(tokens_padded) # (B, S, C)
- tokens_lens = torch.tensor(
- [len(token) for token in tokens], dtype=torch.int64, device=device
- )
- tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
-
- embed = self.text_encoder(
- x=embed, t=None, padding_mask=tokens_padding_mask
- ) # (B, S, C)
- return embed, tokens_lens
-
- def forward_text_condition(
- self,
- embed: torch.Tensor,
- tokens_lens: torch.Tensor,
- features_lens: torch.Tensor,
- ):
- """
- Get the text condition with the same length as the acoustic features.
- Args:
- embed: the text embeddings, shape (batch, token_seq_len, emb_dim).
- tokens_lens: the length of each token sequence, shape (batch,).
- features_lens: the length of each acoustic feature sequence,
- shape (batch,).
- Returns:
- text_condition: the text condition, shape
- (batch, feature_seq_len, emb_dim).
- padding_mask: the padding mask of text condition, shape
- (batch, feature_seq_len).
- """
-
- num_frames = int(features_lens.max())
-
- padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
-
- tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens)
-
- tokens_index = get_tokens_index(tokens_durations, num_frames).to(
- embed.device
- ) # (B, T)
-
- text_condition = torch.gather(
- embed,
- dim=1,
- index=tokens_index.unsqueeze(-1).expand(
- embed.size(0), num_frames, embed.size(-1)
- ),
- ) # (B, T, F)
- return text_condition, padding_mask
-
- def forward_text_train(
- self,
- tokens: List[List[int]],
- features_lens: torch.Tensor,
- ):
- """
- Process text for training, given text tokens and real feature lengths.
- """
- embed, tokens_lens = self.forward_text_embed(tokens)
- text_condition, padding_mask = self.forward_text_condition(
- embed, tokens_lens, features_lens
- )
- return (
- text_condition,
- padding_mask,
- )
-
- def forward_text_inference_gt_duration(
- self,
- tokens: List[List[int]],
- features_lens: torch.Tensor,
- prompt_tokens: List[List[int]],
- prompt_features_lens: torch.Tensor,
- ):
- """
- Process text for inference, given text tokens, real feature lengths and prompts.
- """
- tokens = [
- prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
- ]
- features_lens = prompt_features_lens + features_lens
- embed, tokens_lens = self.forward_text_embed(tokens)
- text_condition, padding_mask = self.forward_text_condition(
- embed, tokens_lens, features_lens
- )
- return text_condition, padding_mask
-
- def forward_text_inference_ratio_duration(
- self,
- tokens: List[List[int]],
- prompt_tokens: List[List[int]],
- prompt_features_lens: torch.Tensor,
- speed: float,
- ):
- """
- Process text for inference, given text tokens and prompts;
- feature lengths are predicted from the prompt's frames-per-token ratio.
- """
- device = (
- self.device if isinstance(self, DDP) else next(self.parameters()).device
- )
-
- cat_tokens = [
- prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
- ]
-
- prompt_tokens_lens = torch.tensor(
- [len(token) for token in prompt_tokens], dtype=torch.int64, device=device
- )
-
- cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens)
-
- features_lens = torch.ceil(
- (prompt_features_lens / prompt_tokens_lens * cat_tokens_lens / speed)
- ).to(dtype=torch.int64)
-
- text_condition, padding_mask = self.forward_text_condition(
- cat_embed, cat_tokens_lens, features_lens
- )
- return text_condition, padding_mask
-
- def forward(
- self,
- tokens: List[List[int]],
- features: torch.Tensor,
- features_lens: torch.Tensor,
- noise: torch.Tensor,
- t: torch.Tensor,
- condition_drop_ratio: float = 0.0,
- ) -> torch.Tensor:
- """Forward pass of the model for training.
- Args:
- tokens: a list of list of token ids.
- features: the acoustic features, with the shape (batch, seq_len, feat_dim).
- features_lens: the length of each acoustic feature sequence, shape (batch,).
- noise: the initial noise, with the shape (batch, seq_len, feat_dim).
- t: the time step, with the shape (batch, 1, 1).
- condition_drop_ratio: the ratio of dropped text condition.
- Returns:
- fm_loss: the flow-matching loss.
- """
-
- (text_condition, padding_mask,) = self.forward_text_train(
- tokens=tokens,
- features_lens=features_lens,
- )
-
- speech_condition_mask = condition_time_mask(
- features_lens=features_lens,
- mask_percent=(0.7, 1.0),
- max_len=features.size(1),
- )
- speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
-
- if condition_drop_ratio > 0.0:
- drop_mask = (
- torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
- > condition_drop_ratio
- )
- text_condition = text_condition * drop_mask
-
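-        
- # Flow-matching with a straight interpolation path:
- # x_t = t * features + (1 - t) * noise, so the target velocity is
- # u_t = d x_t / d t = features - noise.  The model regresses u_t on the
- # masked (non-condition, non-padding) frames below.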
- xt = features * t + noise * (1 - t)
- ut = features - noise # (B, T, F)
-
- vt = self.forward_fm_decoder(
- t=t,
- xt=xt,
- text_condition=text_condition,
- speech_condition=speech_condition,
- padding_mask=padding_mask,
- )
-
- loss_mask = speech_condition_mask & (~padding_mask)
- fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
-
- return fm_loss
-
- def sample(
- self,
- tokens: List[List[int]],
- prompt_tokens: List[List[int]],
- prompt_features: torch.Tensor,
- prompt_features_lens: torch.Tensor,
- features_lens: Optional[torch.Tensor] = None,
- speed: float = 1.0,
- t_shift: float = 1.0,
- duration: str = "predict",
- num_step: int = 5,
- guidance_scale: float = 0.5,
- ) -> torch.Tensor:
- """
- Generate acoustic features, given text tokens, prompt features,
- and the prompt transcription's text tokens.
- Args:
- tokens: a list of list of text tokens.
- prompt_tokens: a list of list of prompt tokens.
- prompt_features: the prompt feature with the shape
- (batch_size, seq_len, feat_dim).
- prompt_features_lens: the length of each prompt feature,
- with the shape (batch_size,).
- features_lens: the length of the predicted feature, with the
- shape (batch_size,). It is used only when duration is "real".
- speed: speed control factor; larger values produce shorter
- (faster) speech. Used only when duration is "predict".
- duration: "real" or "predict". If "real", the predicted
- feature length is given by features_lens.
- num_step: the number of steps to use in the ODE solver.
- guidance_scale: the guidance scale for classifier-free guidance.
- Note: whether the distillation model is used is controlled by self.distill.
- """
-
- assert duration in ["real", "predict"]
-
- if duration == "predict":
- (
- text_condition,
- padding_mask,
- ) = self.forward_text_inference_ratio_duration(
- tokens=tokens,
- prompt_tokens=prompt_tokens,
- prompt_features_lens=prompt_features_lens,
- speed=speed,
- )
- else:
- assert features_lens is not None
- text_condition, padding_mask = self.forward_text_inference_gt_duration(
- tokens=tokens,
- features_lens=features_lens,
- prompt_tokens=prompt_tokens,
- prompt_features_lens=prompt_features_lens,
- )
- batch_size, num_frames, _ = text_condition.shape
-
- speech_condition = torch.nn.functional.pad(
- prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
- ) # (B, T, F)
-
- # False means speech condition positions.
- speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
- speech_condition = torch.where(
- speech_condition_mask.unsqueeze(-1),
- torch.zeros_like(speech_condition),
- speech_condition,
- )
-
- x0 = torch.randn(
- batch_size, num_frames, self.feat_dim, device=text_condition.device
- )
- solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
-
- x1 = solver.sample(
- x=x0,
- text_condition=text_condition,
- speech_condition=speech_condition,
- padding_mask=padding_mask,
- num_step=num_step,
- guidance_scale=guidance_scale,
- t_shift=t_shift,
- )
- x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens
- x1_prompt = torch.zeros(
- x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device
- )
- x1_wo_prompt = torch.zeros(
- x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device
- )
- for i in range(x1.size(0)):
- x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[
- i,
- prompt_features_lens[i] : prompt_features_lens[i]
- + x1_wo_prompt_lens[i],
- ]
- x1_prompt[i, : prompt_features_lens[i], :] = x1[
- i, : prompt_features_lens[i]
- ]
-
- return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens
-
- def sample_intermediate(
- self,
- tokens: List[List[int]],
- features: torch.Tensor,
- features_lens: torch.Tensor,
- noise: torch.Tensor,
- speech_condition_mask: torch.Tensor,
- t_start: torch.Tensor,
- t_end: torch.Tensor,
- num_step: int = 1,
- guidance_scale: torch.Tensor = None,
- ) -> torch.Tensor:
- """
- Generate acoustic features in intermediate timesteps.
- Args:
- tokens: List of list of token ids.
- features: The acoustic features, with the shape (batch, seq_len, feat_dim).
- features_lens: The length of each acoustic feature sequence,
- with the shape (batch,).
- noise: The initial noise, with the shape (batch, seq_len, feat_dim).
- speech_condition_mask: The mask for speech condition, True means
- non-condition positions, with the shape (batch, seq_len).
- t_start: The start timestep, with the shape (batch, 1, 1).
- t_end: The end timestep, with the shape (batch, 1, 1).
- num_step: The number of steps for sampling.
- guidance_scale: The scale for classifier-free guidance inference,
- with the shape (batch, 1, 1).
- Note: whether the distillation model is used is controlled by self.distill.
- """
- (text_condition, padding_mask,) = self.forward_text_train(
- tokens=tokens,
- features_lens=features_lens,
- )
-
- speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
-
- solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
-
- x_t_end = solver.sample(
- x=noise,
- text_condition=text_condition,
- speech_condition=speech_condition,
- padding_mask=padding_mask,
- num_step=num_step,
- guidance_scale=guidance_scale,
- t_start=t_start,
- t_end=t_end,
- )
- x_t_end_lens = (~padding_mask).sum(-1)
- return x_t_end, x_t_end_lens
-
-
-class DistillTTSModelTrainWrapper(TtsModel):
- """Wrapper for training the distilled TTS model."""
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.distill = True
-
- def forward(
- self,
- tokens: List[List[int]],
- features: torch.Tensor,
- features_lens: torch.Tensor,
- noise: torch.Tensor,
- speech_condition_mask: torch.Tensor,
- t_start: torch.Tensor,
- t_end: torch.Tensor,
- num_step: int = 1,
- guidance_scale: torch.Tensor = None,
- ) -> torch.Tensor:
-
- return self.sample_intermediate(
- tokens=tokens,
- features=features,
- features_lens=features_lens,
- noise=noise,
- speech_condition_mask=speech_condition_mask,
- t_start=t_start,
- t_end=t_end,
- num_step=num_step,
- guidance_scale=guidance_scale,
- )
diff --git a/egs/zipvoice/zipvoice/optim.py b/egs/zipvoice/zipvoice/optim.py
deleted file mode 100644
index daf17556a..000000000
--- a/egs/zipvoice/zipvoice/optim.py
+++ /dev/null
@@ -1,1256 +0,0 @@
-# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
-#
-# See ../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import contextlib
-import logging
-import random
-from collections import defaultdict
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from lhotse.utils import fix_random_seed
-from torch import Tensor
-from torch.optim import Optimizer
-
-
-class BatchedOptimizer(Optimizer):
- """
- This class adds to class Optimizer the capability to optimize parameters in batches:
- it will stack the parameters and their grads for you so the optimizer can work
- on tensors with an extra leading dimension. This is intended for speed with GPUs,
- as it reduces the number of kernels launched in the optimizer.
-
- Args:
- params: the parameters or parameter groups to optimize (as for a
- torch.optim.Optimizer).
- defaults: a dict of default hyper-parameter values (as for a
- torch.optim.Optimizer).
- """
-
- def __init__(self, params, defaults):
- super(BatchedOptimizer, self).__init__(params, defaults)
-
- @contextlib.contextmanager
- def batched_params(self, param_group, group_params_names):
- """
- This function returns (technically, yields) a list of
- tuples (p, state, p_names), where
- p is a `fake` parameter that is stacked (over axis 0) from real parameters
- that share the same shape, and its gradient is also stacked;
- `state` is the state corresponding to this batch of parameters
- (it will be physically located in the "state" for one of the real
- parameters, the first one that has any particular shape and dtype).
-
- This function is decorated as a context manager so that it can
- write parameters back to their "real" locations.
-
- The idea is, instead of doing:
-
- for p in group["params"]:
- state = self.state[p]
- ...
-
- you can do:
-
- with self.batched_params(group["params"], group_params_names) as batches:
- for p, state, p_names in batches:
- ...
-
-
- Args:
- group: a parameter group, which is a list of parameters; should be
- one of self.param_groups.
- group_params_names: name for each parameter in group,
- which is List[str].
- """
- batches = defaultdict(
- list
- ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
- batches_names = defaultdict(
- list
- ) # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
-
- assert len(param_group) == len(group_params_names)
- for p, named_p in zip(param_group, group_params_names):
- key = (str(p.dtype), *p.shape)
- batches[key].append(p)
- batches_names[key].append(named_p)
-
- batches_names_keys = list(batches_names.keys())
- sorted_idx = sorted(
- range(len(batches_names)), key=lambda i: batches_names_keys[i]
- )
- batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
- batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
-
- stacked_params_dict = dict()
-
- # turn batches into a list, in deterministic order.
- # tuples will contain tuples of (stacked_param, state, stacked_params_names),
- # one for each batch in `batches`.
- tuples = []
-
- for batch, batch_names in zip(batches, batches_names):
- p = batch[0]
- # we arbitrarily store the state in the
- # state corresponding to the 1st parameter in the
- # group. class Optimizer will take care of saving/loading state.
- state = self.state[p]
- p_stacked = torch.stack(batch)
- grad = torch.stack(
- [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
- )
- p_stacked.grad = grad
- stacked_params_dict[(str(p.dtype), *p.shape)] = p_stacked
- tuples.append((p_stacked, state, batch_names))
-
- yield tuples # <-- calling code will do the actual optimization here!
-
- for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
- for i, p in enumerate(batch): # batch is list of Parameter
- p.copy_(stacked_params[i])
-
-
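-# The parameter update below is composed of three nested stages:
-# basic_step() performs the Adam-style second-moment normalization,
-# scaling_step() wraps it and rescales the step by the parameter's RMS
-# (and also learns that scale), and momentum_step() wraps both and applies
-# beta1 momentum.  ScaledAdam.step() calls momentum_step() for each batch of
-# stacked parameters.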
-def basic_step(group, p, state, grad):
- # Computes the basic Adam update using beta2 only (dividing by the gradient stddev); no momentum yet.
- lr = group["lr"]
- if p.numel() == p.shape[0]:
- lr = lr * group["scalar_lr_scale"]
- beta2 = group["betas"][1]
- eps = group["eps"]
- # p shape: (batch_size,) or (batch_size, 1, [1,..])
- try:
- exp_avg_sq = state[
- "exp_avg_sq"
- ] # shape: (batch_size,) or (batch_size, 1, [1,..])
- except KeyError:
- exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
- state["exp_avg_sq"] = exp_avg_sq
-
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
-
- # bias_correction2 is like in Adam.
- # slower update at the start will help stability anyway.
- bias_correction2 = 1 - beta2 ** (state["step"] + 1)
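- # e.g. with beta2 = 0.98, bias_correction2 is 0.02 at step 0 and only
- # exceeds 0.99 after roughly 230 steps, after which the correction is
- # skipped as an approximation.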
- if bias_correction2 < 0.99:
- # note: not in-place.
- exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
- denom = exp_avg_sq.sqrt().add_(eps)
-
- return -lr * grad / denom
-
-
-def scaling_step(group, p, state, grad):
- delta = basic_step(group, p, state, grad)
- if p.numel() == p.shape[0]:
- return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.)
-
- step = state["step"]
- size_update_period = group["size_update_period"]
-
- try:
- param_rms = state["param_rms"]
- scale_grads = state["scale_grads"]
- scale_exp_avg_sq = state["scale_exp_avg_sq"]
- except KeyError:
- # we know p.ndim > 1 because we'd have returned above if not, so don't worry
- # about the special case of dim=[] that pytorch treats inconsistently.
- param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
- param_rms = param_rms.to(torch.float)
- scale_exp_avg_sq = torch.zeros_like(param_rms)
- scale_grads = torch.zeros(
- size_update_period, *param_rms.shape, dtype=torch.float, device=p.device
- )
- state["param_rms"] = param_rms
- state["scale_grads"] = scale_grads
- state["scale_exp_avg_sq"] = scale_exp_avg_sq
-
- # on every step, update the gradient w.r.t. the scale of the parameter, we
- # store these as a batch and periodically update the size (for speed only, to
- # avoid too many operations).
- scale_grads[step % size_update_period] = (p * grad).sum(
- dim=list(range(1, p.ndim)), keepdim=True
- )
-
- # periodically recompute the value of param_rms.
- if step % size_update_period == size_update_period - 1:
- param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
-
- param_min_rms = group["param_min_rms"]
-
- # scale the step size by param_rms. This is the most important "scaling" part of
- # ScaledAdam
- delta *= param_rms.clamp(min=param_min_rms)
-
- if step % size_update_period == size_update_period - 1 and step > 0:
- # This block updates the size of parameter by adding a step ("delta") value in
- # the direction of either shrinking or growing it.
- beta2 = group["betas"][1]
- size_lr = group["lr"] * group["scalar_lr_scale"]
- param_max_rms = group["param_max_rms"]
- eps = group["eps"]
- batch_size = p.shape[0]
- # correct beta2 for the size update period: we will have
- # faster decay at this level.
- beta2_corr = beta2**size_update_period
- scale_exp_avg_sq.mul_(beta2_corr).add_(
- (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
- alpha=1 - beta2_corr,
- ) # shape is (batch_size, 1, 1, ...)
-
- # The 1st time we reach here is when size_step == 1.
- size_step = (step + 1) // size_update_period
- bias_correction2 = 1 - beta2_corr**size_step
-
- denom = scale_exp_avg_sq.sqrt() + eps
-
- scale_step = (
- -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
- )
-
- is_too_small = param_rms < param_min_rms
-
- # when the param gets too small, just don't shrink it any further.
- scale_step.masked_fill_(is_too_small, 0.0)
-
- # The following may help prevent instability: don't allow the scale step to be too large in
- # either direction.
- scale_step.clamp_(min=-0.1, max=0.1)
-
- # and ensure the parameter rms after update never exceeds param_max_rms.
- # We have to look at the trained model for parameters at or around the
- # param_max_rms, because sometimes they can indicate a problem with the
- # topology or settings.
- scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
-
- delta.add_(p * scale_step)
-
- return delta
-
-
-def momentum_step(group, p, state, grad):
- delta = scaling_step(group, p, state, grad)
- beta1 = group["betas"][0]
- try:
- stored_delta = state["delta"]
- except KeyError:
- stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
- state["delta"] = stored_delta
- stored_delta.mul_(beta1)
- stored_delta.add_(delta, alpha=(1 - beta1))
- # we don't bother doing the "bias correction" part of Adam for beta1 because this is just
- # an edge effect that affects the first 10 or so batches; and the effect of not doing it
- # is just to do a slower update for the first few batches, which will help stability.
- return stored_delta
-
-
-class ScaledAdam(BatchedOptimizer):
- """
- Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
- proportional to the norm of that parameter; and also learn the scale of the parameter,
- in log space, subject to upper and lower limits (as if we had factored each parameter as
- param = underlying_param * log_scale.exp())
-
-
- Args:
- params: The parameters or param_groups to optimize (like other Optimizer subclasses)
- Unlike common optimizers, which accept model.parameters() or groups of parameters,
- this optimizer can also accept model.named_parameters() or groups of named_parameters.
- See the comments of _get_names_of_parameters for the 4 supported cases.
- lr: The learning rate. We will typically use a learning rate schedule that starts
- at 0.03 and decreases over time, i.e. much higher than other common
- optimizers.
- clipping_scale: (e.g. 2.0)
- A scale for gradient-clipping: if specified, the normalized gradients
- over the whole model will be clipped to have 2-norm equal to
- `clipping_scale` times the median 2-norm over the most recent period
- of `clipping_update_period` minibatches. By "normalized gradients",
- we mean after multiplying by the rms parameter value for this tensor
- [for non-scalars]; this is appropriate because our update is scaled
- by this quantity.
- betas: beta1, beta2 are momentum constants for regular momentum and the moving sum-sq grad.
- Must satisfy 0 < beta1 <= beta2 < 1.
- scalar_lr_scale: A scaling factor on the learning rate that we use to update the scale of each parameter tensor and the scalar parameters of the model.
- If each parameter were decomposed
- as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
- would be the scaling factor on the learning rate of p_scale.
- eps: A general-purpose epsilon to prevent division by zero
- param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
- learning the scale on the parameters (we'll constrain the rms of each non-scalar
- parameter tensor to be >= this value)
- param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
- learning the scale on the parameters (we'll constrain the rms of each non-scalar
- parameter tensor to be <= this value)
- scalar_max: Maximum absolute value for scalar parameters (applicable if your
- model has any parameters with numel() == 1).
- size_update_period: The periodicity, in steps, with which we update the size (scale)
- of the parameter tensor. This is provided to save a little time
- in the update.
- clipping_update_period: if clipping_scale is specified, this is the period (in batches) with which the gradient-norm statistics used for clipping are recomputed.
- """
-
- def __init__(
- self,
- params,
- lr=3e-02,
- clipping_scale=None,
- betas=(0.9, 0.98),
- scalar_lr_scale=0.1,
- eps=1.0e-08,
- param_min_rms=1.0e-05,
- param_max_rms=3.0,
- scalar_max=10.0,
- size_update_period=4,
- clipping_update_period=100,
- ):
-
- defaults = dict(
- lr=lr,
- clipping_scale=clipping_scale,
- betas=betas,
- scalar_lr_scale=scalar_lr_scale,
- eps=eps,
- param_min_rms=param_min_rms,
- param_max_rms=param_max_rms,
- scalar_max=scalar_max,
- size_update_period=size_update_period,
- clipping_update_period=clipping_update_period,
- )
-
- # If params only contains parameters or groups of parameters,
- # i.e. when parameter names are not given,
- # this flag will be set to False in function _get_names_of_parameters.
- self.show_dominant_parameters = True
- param_groups, parameters_names = self._get_names_of_parameters(params)
- super(ScaledAdam, self).__init__(param_groups, defaults)
- assert len(self.param_groups) == len(parameters_names)
- self.parameters_names = parameters_names
-
- def _get_names_of_parameters(
- self, params_or_named_params
- ) -> Tuple[List[Dict], List[List[str]]]:
- """
- Args:
- params_or_named_params: according to the way ScaledAdam is initialized in train.py,
- this argument could be one of the following 4 cases,
- case 1, a generator of parameter, e.g.:
- optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
-
- case 2, a list of parameter groups with different config, e.g.:
- model_param_groups = [
- {'params': model.encoder.parameters(), 'lr': 0.05},
- {'params': model.decoder.parameters(), 'lr': 0.01},
- {'params': model.joiner.parameters(), 'lr': 0.03},
- ]
- optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
-
- case 3, a generator of named_parameter, e.g.:
- optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
-
- case 4, a list of named_parameter groups with different config, e.g.:
- model_named_param_groups = [
- {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
- {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
- {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
- ]
- optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
-
- For case 1 and case 2, input params is used to initialize the underlying torch.optimizer.
- For case 3 and case 4, firstly, names and params are extracted from input named_params,
- then, these extracted params are used to initialize the underlying torch.optimizer,
- and these extracted names are mainly used by function
- `_show_gradient_dominating_parameter`
-
- Returns:
- Returns a tuple containing 2 elements:
- - `param_groups` with type List[Dict], each Dict element is a parameter group.
- An example of `param_groups` could be:
- [
- {'params': `one iterable of Parameter`, 'lr': 0.05},
- {'params': `another iterable of Parameter`, 'lr': 0.08},
- {'params': `a third iterable of Parameter`, 'lr': 0.1},
- ]
- - `param_groups_names` with type List[List[str]],
- each `List[str]` is for a group['params'] in param_groups,
- and each `str` is the name of a parameter.
- A dummy name "foo" is assigned to each parameter
- if the input contains params without names, i.e. case 1 or case 2.
- """
- # variable naming convention in this function:
- # p is short for param.
- # np is short for named_param.
- # p_or_np is short for param_or_named_param.
- # cur is short for current.
- # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
- # groups is a List[group]
-
- iterable_or_groups = list(params_or_named_params)
- if len(iterable_or_groups) == 0:
- raise ValueError("optimizer got an empty parameter list")
-
- # The first value of returned tuple. A list of dicts containing at
- # least 'params' as a key.
- param_groups = []
-
- # The second value of returned tuple,
- # a List[List[str]], each sub-List is for a group.
- param_groups_names = []
-
- if not isinstance(iterable_or_groups[0], dict):
- # case 1 or case 3,
- # the input is an iterable of parameter or named parameter.
- param_iterable_cur_group = []
- param_names_cur_group = []
- for p_or_np in iterable_or_groups:
- if isinstance(p_or_np, tuple):
- # case 3
- name, param = p_or_np
- else:
- # case 1
- assert isinstance(p_or_np, torch.Tensor)
- param = p_or_np
- # Assign a dummy name as a placeholder
- name = "foo"
- self.show_dominant_parameters = False
- param_iterable_cur_group.append(param)
- param_names_cur_group.append(name)
- param_groups.append({"params": param_iterable_cur_group})
- param_groups_names.append(param_names_cur_group)
- else:
- # case 2 or case 4
- # the input is groups of parameter or named parameter.
- for cur_group in iterable_or_groups:
- if "named_params" in cur_group:
- name_list = [x[0] for x in cur_group["named_params"]]
- p_list = [x[1] for x in cur_group["named_params"]]
- del cur_group["named_params"]
- cur_group["params"] = p_list
- else:
- assert "params" in cur_group
- name_list = ["foo" for _ in cur_group["params"]]
- param_groups.append(cur_group)
- param_groups_names.append(name_list)
-
- return param_groups, param_groups_names
-
- def __setstate__(self, state):
- super(ScaledAdam, self).__setstate__(state)
-
- @torch.no_grad()
- def step(self, closure=None):
- """Performs a single optimization step.
-
- Arguments:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- batch = True
-
- for group, group_params_names in zip(self.param_groups, self.parameters_names):
-
- with self.batched_params(group["params"], group_params_names) as batches:
-
- # batches is list of pairs (stacked_param, state). stacked_param is like
- # a regular parameter, and will have a .grad, but the 1st dim corresponds to
- # a stacking dim, it is not a real dim.
-
- if (
- len(batches[0][1]) == 0
- ): # if len(first state) == 0: not yet initialized
- clipping_scale = 1
- else:
- clipping_scale = self._get_clipping_scale(group, batches)
-
- for p, state, _ in batches:
- # Perform optimization step.
- # grad is not going to be None, we handled that when creating the batches.
- grad = p.grad
- if grad.is_sparse:
- raise RuntimeError(
- "ScaledAdam optimizer does not support sparse gradients"
- )
-
- try:
- cur_step = state["step"]
- except KeyError:
- state["step"] = 0
- cur_step = 0
-
- grad = (
- p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)
- )
- p += momentum_step(group, p.detach(), state, grad)
-
- if p.numel() == p.shape[0]: # scalar parameter
- scalar_max = group["scalar_max"]
- p.clamp_(min=-scalar_max, max=scalar_max)
-
- state["step"] = cur_step + 1
-
- return loss
-
- def _get_clipping_scale(
- self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
- ) -> float:
- """
- Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
- by this amount before applying the rest of the update.
-
- Args:
- group: the parameter group, an item in self.param_groups
- tuples: a list of tuples of (param, state, param_names)
- where param is a batched set of parameters,
- with a .grad (1st dim is batch dim)
- and state is the state-dict where optimization parameters are kept.
- param_names is a List[str], where each str is the name of a parameter
- in the batched set of parameters "param".
- """
- assert len(tuples) >= 1
- clipping_scale = group["clipping_scale"]
- (first_p, first_state, _) = tuples[0]
- step = first_state["step"]
- if clipping_scale is None or step == 0:
- # no clipping. return early on step == 0 because the other
- # parameters' state won't have been initialized yet.
- return 1.0
- clipping_update_period = group["clipping_update_period"]
- scalar_lr_scale = group["scalar_lr_scale"]
-
- tot_sumsq = torch.tensor(0.0, device=first_p.device)
- for (p, state, param_names) in tuples:
- grad = p.grad
- if grad.is_sparse:
- raise RuntimeError(
- "ScaledAdam optimizer does not support sparse gradients"
- )
- if p.numel() == p.shape[0]: # a batch of scalars
- tot_sumsq += (grad**2).sum() * (
- scalar_lr_scale**2
- ) # sum() to change shape [1] to []
- else:
- tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
-
- tot_norm = tot_sumsq.sqrt()
- if "model_norms" not in first_state:
- first_state["model_norms"] = torch.zeros(
- clipping_update_period, device=p.device
- )
- first_state["model_norms"][step % clipping_update_period] = tot_norm
-
- irregular_estimate_steps = [
- i for i in [10, 20, 40] if i < clipping_update_period
- ]
- if step % clipping_update_period == 0 or step in irregular_estimate_steps:
- # Print some stats.
- # We don't reach here if step == 0 because we would have returned
- # above.
- sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
- if step in irregular_estimate_steps:
- sorted_norms = sorted_norms[-step:]
- num_norms = sorted_norms.numel()
- quartiles = []
- for n in range(0, 5):
- index = min(num_norms - 1, (num_norms // 4) * n)
- quartiles.append(sorted_norms[index].item())
-
- median = quartiles[2]
- if median - median != 0:
- raise RuntimeError("Too many grads were not finite")
- threshold = clipping_scale * median
- if step in irregular_estimate_steps:
- # use larger thresholds on first few steps of estimating threshold,
- # as norm may be changing rapidly.
- threshold = threshold * 2.0
- first_state["model_norm_threshold"] = threshold
- percent_clipped = (
- first_state["num_clipped"] * 100.0 / num_norms
- if "num_clipped" in first_state
- else 0.0
- )
- first_state["num_clipped"] = 0
- quartiles = " ".join(["%.3e" % x for x in quartiles])
- logging.warning(
- f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
- f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
- )
-
- try:
- model_norm_threshold = first_state["model_norm_threshold"]
- except KeyError:
- return 1.0 # threshold has not yet been set.
-
- ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
- if ans != ans: # e.g. ans is nan
- ans = 0.0
- if ans < 1.0:
- first_state["num_clipped"] += 1
- if ans < 0.5:
- logging.warning(
- f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
- )
- if self.show_dominant_parameters:
- assert p.shape[0] == len(param_names)
- self._show_gradient_dominating_parameter(
- tuples, tot_sumsq, group["scalar_lr_scale"]
- )
- self._show_param_with_unusual_grad(tuples)
-
- if ans == 0.0:
- for (p, state, param_names) in tuples:
- p.grad.zero_() # get rid of infinite/NaN gradients
-
- return ans
-
- def _show_param_with_unusual_grad(
- self,
- tuples: List[Tuple[Tensor, dict, List[str]]],
- ):
- """
- Print information about the parameters with the largest ratio of grad-on-this-batch
- divided by the typical grad size.
- tuples: a list of tuples of (param, state, param_names)
- where param is a batched set of parameters,
- with a .grad (1st dim is batch dim)
- and state is the state-dict where optimization parameters are kept.
- param_names is a List[str], where each str is the name of a parameter
- in the batched set of parameters "param".
- """
- largest_ratio = 0.0
- largest_name = ""
- # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
- ratios_names = []
- for (p, state, batch_param_names) in tuples:
- dims = list(range(1, p.ndim))
-
- def mean(x):
- # workaround for bad interface of torch's "mean" for when dims is the empty list.
- if len(dims) > 0:
- return x.mean(dim=dims)
- else:
- return x
-
- grad_ratio = (
- (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
- .sqrt()
- .to("cpu")
- )
-
- ratios_names += zip(
- grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
- )
-
- ratios_names = sorted(ratios_names, reverse=True)
- ratios_names = ratios_names[:10]
- ratios_names = [
- (ratio, name, largest_index(tensor))
- for (ratio, name, tensor) in ratios_names
- ]
-
- logging.debug(
- f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}"
- )
-
- def _show_gradient_dominating_parameter(
- self,
- tuples: List[Tuple[Tensor, dict, List[str]]],
- tot_sumsq: Tensor,
- scalar_lr_scale: float,
- ):
- """
- Show information about the parameter that dominates tot_sumsq.
-
- Args:
- tuples: a list of tuples of (param, state, param_names)
- where param is a batched set of parameters,
- with a .grad (1st dim is batch dim)
- and state is the state-dict where optimization parameters are kept.
- param_names is a List[str], where each str is the name of a parameter
- in the batched set of parameters "param".
- tot_sumsq: sumsq of all parameters. Though it could be calculated
- from tuples, we pass it in to save some time.
- """
- all_sumsq_orig = {}
- for (p, state, batch_param_names) in tuples:
- # p is a stacked batch parameters.
- batch_grad = p.grad
- if p.numel() == p.shape[0]: # a batch of scalars
- # Dummy values used by following `zip` statement.
- batch_rms_orig = torch.full(
- p.shape, scalar_lr_scale, device=batch_grad.device
- )
- else:
- batch_rms_orig = state["param_rms"]
- batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
- if batch_grad.ndim > 1:
- # need to guard it with if-statement because sum() sums over
- # all dims if dim == ().
- batch_sumsq_orig = batch_sumsq_orig.sum(
- dim=list(range(1, batch_grad.ndim))
- )
- for name, sumsq_orig, rms, grad in zip(
- batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
- ):
-
- proportion_orig = sumsq_orig / tot_sumsq
- all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
-
- sorted_by_proportion = {
- k: v
- for k, v in sorted(
- all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
- )
- }
- dominant_param_name = next(iter(sorted_by_proportion))
- (
- dominant_proportion,
- dominant_sumsq,
- dominant_rms,
- dominant_grad,
- ) = sorted_by_proportion[dominant_param_name]
- logging.debug(
- f"Parameter dominating tot_sumsq {dominant_param_name}"
- f" with proportion {dominant_proportion:.2f},"
- f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
- f"={dominant_sumsq:.3e},"
- f" grad_sumsq={(dominant_grad**2).sum():.3e},"
- f" orig_rms_sq={(dominant_rms**2).item():.3e}"
- )
-
-
-def largest_index(x: Tensor):
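- # Returns the multi-dimensional index (as a list of ints) of the element of
- # x with the largest absolute value, e.g. [1, 2] for a 2-D tensor whose
- # largest-magnitude entry is at row 1, column 2.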
- x = x.contiguous()
- argmax = x.abs().argmax().item()
- return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]
-
-
-class LRScheduler(object):
- """
- Base-class for learning rate schedulers where the learning-rate depends on both the
- batch and the epoch.
- """
-
- def __init__(self, optimizer: Optimizer, verbose: bool = False):
- # Attach optimizer
- if not isinstance(optimizer, Optimizer):
- raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
- self.optimizer = optimizer
- self.verbose = verbose
-
- for group in optimizer.param_groups:
- group.setdefault("base_lr", group["lr"])
-
- self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
-
- self.epoch = 0
- self.batch = 0
-
- def state_dict(self):
- """Returns the state of the scheduler as a :class:`dict`.
-
- It contains an entry for every variable in self.__dict__ which
- is not the optimizer.
- """
- return {
- # the user might try to override the base_lr, so don't include this in the state.
- # previously they were included.
- # "base_lrs": self.base_lrs,
- "epoch": self.epoch,
- "batch": self.batch,
- }
-
- def load_state_dict(self, state_dict):
- """Loads the schedulers state.
-
- Args:
- state_dict (dict): scheduler state. Should be an object returned
- from a call to :meth:`state_dict`.
- """
- # the things with base_lrs are a work-around for a previous problem
- # where base_lrs were written with the state dict.
- base_lrs = self.base_lrs
- self.__dict__.update(state_dict)
- self.base_lrs = base_lrs
-
- def get_last_lr(self) -> List[float]:
- """Return last computed learning rate by current scheduler. Will be a list of float."""
- return self._last_lr
-
- def get_lr(self):
- # Compute list of learning rates from self.epoch and self.batch and
- # self.base_lrs; this must be overloaded by the user.
- # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
- raise NotImplementedError
-
- def step_batch(self, batch: Optional[int] = None) -> None:
- # Step the batch index, or just set it. If `batch` is specified, it
- # must be the batch index from the start of training, i.e. summed over
- # all epochs.
- # You can call this in any order; if you don't provide 'batch', it should
- # of course be called once per batch.
- if batch is not None:
- self.batch = batch
- else:
- self.batch = self.batch + 1
- self._set_lrs()
-
- def step_epoch(self, epoch: Optional[int] = None):
- # Step the epoch index, or just set it. If you provide the 'epoch' arg,
- # you should call this at the start of the epoch; if you don't provide the 'epoch'
- # arg, you should call it at the end of the epoch.
- if epoch is not None:
- self.epoch = epoch
- else:
- self.epoch = self.epoch + 1
- self._set_lrs()
-
- def _set_lrs(self):
- values = self.get_lr()
- assert len(values) == len(self.optimizer.param_groups)
-
- for i, data in enumerate(zip(self.optimizer.param_groups, values)):
- param_group, lr = data
- param_group["lr"] = lr
- self.print_lr(self.verbose, i, lr)
- self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
-
- def print_lr(self, is_verbose, group, lr):
- """Display the current learning rate."""
- if is_verbose:
- logging.warning(
- f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
- f" of group {group} to {lr:.4e}."
- )
-
-
-class Eden(LRScheduler):
- """
- Eden scheduler.
- The basic formula (before warmup) is:
- lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
- (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
- where `warmup` increases linearly from warmup_start (default 0.5) to 1 over `warmup_batches` batches
- and then stays constant at 1.
-
- If you don't have the concept of epochs, or one epoch takes a very long time,
- you can replace the notion of 'epoch' with some measure of the amount of data
- processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
- some measure representing "quite a lot of data": say, one fifth or one third
- of an entire training run, but it doesn't matter much. You could also use
- Eden2 which has only the notion of batches.
-
- We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
-
- Args:
- optimizer: the optimizer to change the learning rates on
- lr_batches: the number of batches after which we start significantly
- decreasing the learning rate, suggest 5000.
- lr_epochs: the number of epochs after which we start significantly
- decreasing the learning rate, suggest 6 if you plan to do e.g.
- 20 to 40 epochs, but you may need a smaller number if the dataset is huge
- and you will do only a few epochs.
- """
-
- def __init__(
- self,
- optimizer: Optimizer,
- lr_batches: Union[int, float],
- lr_epochs: Union[int, float],
- warmup_batches: Union[int, float] = 500.0,
- warmup_start: float = 0.5,
- verbose: bool = False,
- ):
- super(Eden, self).__init__(optimizer, verbose)
- self.lr_batches = lr_batches
- self.lr_epochs = lr_epochs
- self.warmup_batches = warmup_batches
-
- assert 0.0 <= warmup_start <= 1.0, warmup_start
- self.warmup_start = warmup_start
-
- def get_lr(self):
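- # Worked example of the decay factor (ignoring warmup): with
- # lr_batches = 5000 and lr_epochs = 6, at batch 5000 of epoch 0 the batch
- # term is ((5000**2 + 5000**2) / 5000**2) ** -0.25 = 2 ** -0.25 ~= 0.84,
- # while the epoch term is still ~1.0, so lr ~= 0.84 * base_lr.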
- factor = (
- (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
- ) ** -0.25 * (
- ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
- )
- warmup_factor = (
- 1.0
- if self.batch >= self.warmup_batches
- else self.warmup_start
- + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
- # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
- )
-
- return [x * factor * warmup_factor for x in self.base_lrs]
-
-
-class Eden2(LRScheduler):
- """
- Eden2 scheduler, simpler than Eden because it does not use the notion of epoch,
- only batches.
-
- The basic formula (before warmup) is:
- lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup
-
- where `warmup` increases linearly from warmup_start (default 0.5) to 1 over
- `warmup_batches` batches and then stays constant at 1.
-
-
- E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
-
- Args:
- optimizer: the optimizer to change the learning rates on
- lr_batches: the number of batches after which we start significantly
- decreasing the learning rate, suggest 5000.
- """
-
- def __init__(
- self,
- optimizer: Optimizer,
- lr_batches: Union[int, float],
- warmup_batches: Union[int, float] = 500.0,
- warmup_start: float = 0.5,
- verbose: bool = False,
- ):
- super().__init__(optimizer, verbose)
- self.lr_batches = lr_batches
- self.warmup_batches = warmup_batches
-
- assert 0.0 <= warmup_start <= 1.0, warmup_start
- self.warmup_start = warmup_start
-
- def get_lr(self):
- factor = (
- (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
- ) ** -0.5
- warmup_factor = (
- 1.0
- if self.batch >= self.warmup_batches
- else self.warmup_start
- + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
- # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
- )
-
- return [x * factor * warmup_factor for x in self.base_lrs]
-
-
-class FixedLRScheduler(LRScheduler):
- """
- Fixed learning rate scheduler.
-
- Args:
- optimizer: the optimizer to change the learning rates on
- """
-
- def __init__(
- self,
- optimizer: Optimizer,
- verbose: bool = False,
- ):
- super(FixedLRScheduler, self).__init__(optimizer, verbose)
-
- def get_lr(self):
-
- return list(self.base_lrs)
-
-
-def _test_eden():
- m = torch.nn.Linear(100, 100)
- optim = ScaledAdam(m.parameters(), lr=0.03)
-
- scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
-
- for epoch in range(10):
- scheduler.step_epoch(epoch) # sets epoch to `epoch`
-
- for step in range(20):
- x = torch.randn(200, 100).detach()
- x.requires_grad = True
- y = m(x)
- dy = torch.randn(200, 100).detach()
- f = (y * dy).sum()
- f.backward()
-
- optim.step()
- scheduler.step_batch()
- optim.zero_grad()
-
- logging.info(f"last lr = {scheduler.get_last_lr()}")
- logging.info(f"state dict = {scheduler.state_dict()}")
-
-
-# This is included mostly as a baseline for ScaledAdam.
-class Eve(Optimizer):
- """
- Implements Eve algorithm. This is a modified version of AdamW with a special
- way of setting the weight-decay / shrinkage-factor, which is designed to make the
- rms of the parameters approach a particular target_rms (default: 0.1). This is
- for use with networks with 'scaled' versions of modules (see scaling.py), which
- will be close to invariant to the absolute scale on the parameter matrix.
-
- The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
- The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
- Eve is unpublished so far.
-
- Arguments:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-3;
- this value means that the weight would decay significantly after
- about 1k minibatches). It is not multiplied by the learning rate, but
- is conditional on the RMS-value of the parameter being > target_rms.
- target_rms (float, optional): target root-mean-square value of
- parameters, if they fall below this we will stop applying weight decay.
-
-
- .. _Adam: A Method for Stochastic Optimization:
- https://arxiv.org/abs/1412.6980
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
- """
-
- def __init__(
- self,
- params,
- lr=1e-3,
- betas=(0.9, 0.98),
- eps=1e-8,
- weight_decay=1e-3,
- target_rms=0.1,
- ):
- if not 0.0 <= lr:
- raise ValueError("Invalid learning rate: {}".format(lr))
- if not 0.0 <= eps:
- raise ValueError("Invalid epsilon value: {}".format(eps))
- if not 0.0 <= betas[0] < 1.0:
- raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
- if not 0.0 <= betas[1] < 1.0:
- raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
- if not 0 <= weight_decay <= 0.1:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- if not 0 < target_rms <= 10.0:
- raise ValueError("Invalid target_rms value: {}".format(target_rms))
- defaults = dict(
- lr=lr,
- betas=betas,
- eps=eps,
- weight_decay=weight_decay,
- target_rms=target_rms,
- )
- super(Eve, self).__init__(params, defaults)
-
- def __setstate__(self, state):
- super(Eve, self).__setstate__(state)
-
- @torch.no_grad()
- def step(self, closure=None):
- """Performs a single optimization step.
-
- Arguments:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
- loss = None
- if closure is not None:
- with torch.enable_grad():
- loss = closure()
-
- for group in self.param_groups:
- for p in group["params"]:
- if p.grad is None:
- continue
-
- # Perform optimization step
- grad = p.grad
- if grad.is_sparse:
- raise RuntimeError("AdamW does not support sparse gradients")
-
- state = self.state[p]
-
- # State initialization
- if len(state) == 0:
- state["step"] = 0
- # Exponential moving average of gradient values
- state["exp_avg"] = torch.zeros_like(
- p, memory_format=torch.preserve_format
- )
- # Exponential moving average of squared gradient values
- state["exp_avg_sq"] = torch.zeros_like(
- p, memory_format=torch.preserve_format
- )
-
- exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-
- beta1, beta2 = group["betas"]
-
- state["step"] += 1
- bias_correction1 = 1 - beta1 ** state["step"]
- bias_correction2 = 1 - beta2 ** state["step"]
-
- # Decay the first and second moment running average coefficient
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
- denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_(
- group["eps"]
- )
-
- step_size = group["lr"] / bias_correction1
- target_rms = group["target_rms"]
- weight_decay = group["weight_decay"]
-
- if p.numel() > 1:
- # avoid applying this weight-decay on "scaling factors"
- # (which are scalar).
- is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
- p.mul_(1 - (weight_decay * is_above_target_rms))
-
- p.addcdiv_(exp_avg, denom, value=-step_size)
-
- if random.random() < 0.0005:
- step = (exp_avg / denom) * step_size
- logging.info(
- f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
- )
-
- return loss
-
-
-def _test_scaled_adam(hidden_dim: int):
- import timeit
-
- from scaling import ScaledLinear
-
- E = 100
- B = 4
- T = 2
- logging.info("in test_eve_cain")
- # device = torch.device('cuda')
- device = torch.device("cpu")
- dtype = torch.float32
-
- fix_random_seed(42)
- # these input_magnitudes and output_magnitudes are to test that
- # the optimizer is working as we expect and is able to adjust scales of
- # different dims differently.
- input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
- output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
-
- for iter in [1, 0]:
- fix_random_seed(42)
- Linear = torch.nn.Linear if iter == 0 else ScaledLinear
-
- m = torch.nn.Sequential(
- Linear(E, hidden_dim),
- torch.nn.PReLU(),
- Linear(hidden_dim, hidden_dim),
- torch.nn.PReLU(),
- Linear(hidden_dim, E),
- ).to(device)
-
- train_pairs = [
- (
- 100.0
- * torch.randn(B, T, E, device=device, dtype=dtype)
- * input_magnitudes,
- torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
- )
- for _ in range(20)
- ]
-
- if iter == 0:
- optim = Eve(m.parameters(), lr=0.003)
- elif iter == 1:
- optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
- scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
-
- start = timeit.default_timer()
- avg_loss = 0.0
- for epoch in range(180):
- scheduler.step_epoch()
- # if epoch == 100 and iter in [2,3]:
- # optim.reset_speedup() # check it doesn't crash.
-
- # if epoch == 130:
- # opts = diagnostics.TensorDiagnosticOptions(
- # 512
- # ) # allow 4 megabytes per sub-module
- # diagnostic = diagnostics.attach_diagnostics(m, opts)
-
- for n, (x, y) in enumerate(train_pairs):
- y_out = m(x)
- loss = ((y_out - y) ** 2).mean() * 100.0
- if epoch == 0 and n == 0:
- avg_loss = loss.item()
- else:
- avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
- if n == 0 and epoch % 5 == 0:
- # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
- # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
- # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
- # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
- # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
- # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
- # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
- # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
- lr = scheduler.get_last_lr()[0]
- logging.info(
- f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
- ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
- loss.log().backward()
- optim.step()
- optim.zero_grad()
- scheduler.step_batch()
-
- # diagnostic.print_diagnostics()
-
- stop = timeit.default_timer()
- logging.info(f"Iter={iter}, Time taken: {stop - start}")
-
- logging.info(f"last lr = {scheduler.get_last_lr()}")
- # logging.info("state dict = ", scheduler.state_dict())
- # logging.info("optim state_dict = ", optim.state_dict())
- logging.info(f"input_magnitudes = {input_magnitudes}")
- logging.info(f"output_magnitudes = {output_magnitudes}")
-
-
-if __name__ == "__main__":
- torch.set_num_threads(1)
- torch.set_num_interop_threads(1)
- logging.getLogger().setLevel(logging.INFO)
- import subprocess
-
- s = subprocess.check_output(
- "git status -uno .; git log -1; git diff HEAD .", shell=True
- )
- logging.info(s)
- import sys
-
- if len(sys.argv) > 1:
- hidden_dim = int(sys.argv[1])
- else:
- hidden_dim = 200
-
- _test_scaled_adam(hidden_dim)
- _test_eden()
diff --git a/egs/zipvoice/zipvoice/scaling.py b/egs/zipvoice/zipvoice/scaling.py
deleted file mode 100644
index afe9ad468..000000000
--- a/egs/zipvoice/zipvoice/scaling.py
+++ /dev/null
@@ -1,1930 +0,0 @@
-# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-import math
-import random
-import sys
-from typing import Optional, Tuple, Union
-
-try:
- import k2
-except Exception as ex:
- logging.warning(
-        "k2 is not installed correctly. Swoosh functions will fall back to "
-        "the PyTorch implementation."
- )
-
-import torch
-import torch.nn as nn
-from torch import Tensor
-
-custom_bwd = lambda func: torch.amp.custom_bwd(func, device_type="cuda")
-custom_fwd = lambda func: torch.amp.custom_fwd(func, device_type="cuda")
-
-
-def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
- max_value = torch.max(x, y)
- diff = torch.abs(x - y)
- return max_value + torch.log1p(torch.exp(-diff))
-
-
-# RuntimeError: Exporting the operator logaddexp to ONNX opset version
-# 14 is not supported. Please feel free to request support or submit
-# a pull request on PyTorch GitHub.
-#
-# The following function is to solve the above error when exporting
-# models to ONNX via torch.jit.trace()
-def logaddexp(x: Tensor, y: Tensor) -> Tensor:
- # Caution(fangjun): Put torch.jit.is_scripting() before
- # torch.onnx.is_in_onnx_export();
- # otherwise, it will cause errors for torch.jit.script().
- #
- # torch.logaddexp() works for both torch.jit.script() and
- # torch.jit.trace() but it causes errors for ONNX export.
- #
- if torch.jit.is_scripting():
- # Note: We cannot use torch.jit.is_tracing() here as it also
- # matches torch.onnx.export().
- return torch.logaddexp(x, y)
- elif torch.onnx.is_in_onnx_export():
- return logaddexp_onnx(x, y)
- else:
- # for torch.jit.trace()
- return torch.logaddexp(x, y)
-
-
-class PiecewiseLinear(object):
- """
-    Piecewise linear function, from float to float, specified as a nonempty list of (x,y) pairs
-    with the x values in order.  x values <[initial x] or >[final x] are mapped to [initial y]
-    and [final y] respectively.
- """
-
- def __init__(self, *args):
- assert len(args) >= 1, len(args)
- if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
- self.pairs = list(args[0].pairs)
- else:
- self.pairs = [(float(x), float(y)) for x, y in args]
- for x, y in self.pairs:
- assert isinstance(x, (float, int)), type(x)
- assert isinstance(y, (float, int)), type(y)
-
- for i in range(len(self.pairs) - 1):
- assert self.pairs[i + 1][0] > self.pairs[i][0], (
- i,
- self.pairs[i],
- self.pairs[i + 1],
- )
-
- def __str__(self):
- # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
- return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
-
- def __call__(self, x):
- if x <= self.pairs[0][0]:
- return self.pairs[0][1]
- elif x >= self.pairs[-1][0]:
- return self.pairs[-1][1]
- else:
- cur_x, cur_y = self.pairs[0]
- for i in range(1, len(self.pairs)):
- next_x, next_y = self.pairs[i]
- if x >= cur_x and x <= next_x:
- return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
- cur_x, cur_y = next_x, next_y
- assert False
-
- def __mul__(self, alpha):
- return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
-
- def __add__(self, x):
- if isinstance(x, (float, int)):
- return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
- s, x = self.get_common_basis(x)
- return PiecewiseLinear(
- *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
- )
-
- def max(self, x):
- if isinstance(x, (float, int)):
- x = PiecewiseLinear((0, x))
- s, x = self.get_common_basis(x, include_crossings=True)
- return PiecewiseLinear(
- *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
- )
-
- def min(self, x):
- if isinstance(x, float) or isinstance(x, int):
- x = PiecewiseLinear((0, x))
- s, x = self.get_common_basis(x, include_crossings=True)
- return PiecewiseLinear(
- *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
- )
-
- def __eq__(self, other):
- return self.pairs == other.pairs
-
- def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
- """
- Returns (self_mod, p_mod) which are equivalent piecewise linear
- functions to self and p, but with the same x values.
-
- p: the other piecewise linear function
- include_crossings: if true, include in the x values positions
-          where the functions represented by self and p cross.
- """
- assert isinstance(p, PiecewiseLinear), type(p)
-
- # get sorted x-values without repetition.
- x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
- y_vals1 = [self(x) for x in x_vals]
- y_vals2 = [p(x) for x in x_vals]
-
- if include_crossings:
- extra_x_vals = []
- for i in range(len(x_vals) - 1):
- if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
- # if the two lines in this subsegment potentially cross each other..
- diff_cur = abs(y_vals1[i] - y_vals2[i])
- diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
- # `pos`, between 0 and 1, gives the relative x position,
- # with 0 being x_vals[i] and 1 being x_vals[i+1].
- pos = diff_cur / (diff_cur + diff_next)
- extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
- extra_x_vals.append(extra_x_val)
- if len(extra_x_vals) > 0:
- x_vals = sorted(set(x_vals + extra_x_vals))
- y_vals1 = [self(x) for x in x_vals]
- y_vals2 = [p(x) for x in x_vals]
- return (
- PiecewiseLinear(*zip(x_vals, y_vals1)),
- PiecewiseLinear(*zip(x_vals, y_vals2)),
- )
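-
-
-# Illustrative usage sketch of PiecewiseLinear (hypothetical values; not used by
-# the recipe): x values outside the given range clamp to the first/last y, and
-# values inside are linearly interpolated.
-def _example_piecewise_linear():
-    sched = PiecewiseLinear((0, 1.0), (1000, 0.1))
-    assert sched(-10) == 1.0  # before the first x -> first y
-    assert sched(2000) == 0.1  # after the last x -> last y
-    assert abs(sched(500) - 0.55) < 1.0e-06  # linear interpolation in between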
-
-
-class ScheduledFloat(torch.nn.Module):
- """
- This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
- it does not have a working forward() function. You are supposed to cast it to float, as
- in, float(parent_module.whatever), and use it as something like a dropout prob.
-
- It is a floating point value whose value changes depending on the batch count of the
- training loop. It is a piecewise linear function where you specify the (x,y) pairs
- in sorted order on x; x corresponds to the batch index. For batch-index values before the
- first x or after the last x, we just use the first or last y value.
-
- Example:
- self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
-
- `default` is used when self.batch_count is not set or not in training mode or in
- torch.jit scripting mode.
- """
-
- def __init__(self, *args, default: float = 0.0):
- super().__init__()
- # self.batch_count and self.name will be written to in the training loop.
- self.batch_count = None
- self.name = None
- self.default = default
- self.schedule = PiecewiseLinear(*args)
-
- def extra_repr(self) -> str:
- return (
- f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
- )
-
- def __float__(self):
- batch_count = self.batch_count
- if (
- batch_count is None
- or not self.training
- or torch.jit.is_scripting()
- or torch.jit.is_tracing()
- ):
- return float(self.default)
- else:
- ans = self.schedule(self.batch_count)
- if random.random() < 0.0002:
- logging.debug(
- f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}"
- )
- return ans
-
- def __add__(self, x):
- if isinstance(x, float) or isinstance(x, int):
- return ScheduledFloat(self.schedule + x, default=self.default)
- else:
- return ScheduledFloat(
- self.schedule + x.schedule, default=self.default + x.default
- )
-
- def max(self, x):
- if isinstance(x, float) or isinstance(x, int):
- return ScheduledFloat(self.schedule.max(x), default=self.default)
- else:
- return ScheduledFloat(
- self.schedule.max(x.schedule), default=max(self.default, x.default)
- )
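-
-
-# Illustrative sketch of how ScheduledFloat is read (hypothetical schedule
-# values; not used by the recipe): the training loop writes `batch_count`, and
-# consumers simply cast the object to float.
-def _example_scheduled_float():
-    dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1), default=0.0)
-    assert float(dropout) == 0.0  # batch_count not set yet -> default
-    dropout.batch_count = 10000.0
-    assert abs(float(dropout) - 0.2) < 1.0e-06  # halfway along the schedule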
-
-
-FloatLike = Union[float, ScheduledFloat]
-
-
-def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
- """
- A randomized way of casting a floating point value to half precision.
- """
- if x.dtype == torch.float16:
- return x
- x_abs = x.abs()
- is_too_small = x_abs < min_abs
- # for elements where is_too_small is true, random_val will contain +-min_abs with
- # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
- # for those elements].
- random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
- return torch.where(is_too_small, random_val, x).to(torch.float16)
-
-
-class CutoffEstimator:
- """
- Estimates cutoffs of an arbitrary numerical quantity such that a specified
- proportion of items will be above the cutoff on average.
-
- p is the proportion of items that should be above the cutoff.
- """
-
- def __init__(self, p: float):
- self.p = p
- # total count of items
- self.count = 0
- # total count of items that were above the cutoff
- self.count_above = 0
- # initial cutoff value
- self.cutoff = 0
-
- def __call__(self, x: float) -> bool:
- """
- Returns true if x is above the cutoff.
- """
- ans = x > self.cutoff
- self.count += 1
- if ans:
- self.count_above += 1
- cur_p = self.count_above / self.count
- delta_p = cur_p - self.p
- if (delta_p > 0) == ans:
- q = abs(delta_p)
- self.cutoff = x * q + self.cutoff * (1 - q)
- return ans
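-
-
-# Illustrative sketch of CutoffEstimator (hypothetical values; not used by the
-# recipe): after many calls, roughly a proportion `p` of new values should come
-# back as above the adaptive cutoff.
-def _example_cutoff_estimator():
-    est = CutoffEstimator(p=0.25)
-    flags = [est(random.random()) for _ in range(10000)]
-    proportion_above = sum(flags) / len(flags)  # expected to be roughly 0.25
-    return proportion_above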
-
-
-class SoftmaxFunction(torch.autograd.Function):
- """
- Tries to handle half-precision derivatives in a randomized way that should
- be more accurate for training than the default behavior.
- """
-
- @staticmethod
- def forward(ctx, x: Tensor, dim: int):
- ans = x.softmax(dim=dim)
- # if x dtype is float16, x.softmax() returns a float32 because
- # (presumably) that op does not support float16, and autocast
- # is enabled.
- if torch.is_autocast_enabled():
- ans = ans.to(torch.float16)
- ctx.save_for_backward(ans)
- ctx.x_dtype = x.dtype
- ctx.dim = dim
- return ans
-
- @staticmethod
- def backward(ctx, ans_grad: Tensor):
- (ans,) = ctx.saved_tensors
- with torch.amp.autocast("cuda", enabled=False):
- ans_grad = ans_grad.to(torch.float32)
- ans = ans.to(torch.float32)
- x_grad = ans_grad * ans
- x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
- return x_grad, None
-
-
-def softmax(x: Tensor, dim: int):
- if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
- return x.softmax(dim=dim)
-
- return SoftmaxFunction.apply(x, dim)
-
-
-class MaxEigLimiterFunction(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- x: Tensor,
- coeffs: Tensor,
- direction: Tensor,
- channel_dim: int,
- grad_scale: float,
- ) -> Tensor:
- ctx.channel_dim = channel_dim
- ctx.grad_scale = grad_scale
- ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
- return x
-
- @staticmethod
- def backward(ctx, x_grad, *args):
- with torch.enable_grad():
- (x_orig, coeffs, new_direction) = ctx.saved_tensors
- x_orig.requires_grad = True
- num_channels = x_orig.shape[ctx.channel_dim]
- x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
- new_direction.requires_grad = False
- x = x - x.mean(dim=0)
- x_var = (x**2).mean()
- x_residual = x - coeffs * new_direction
- x_residual_var = (x_residual**2).mean()
- # `variance_proportion` is the proportion of the variance accounted for
- # by the top eigen-direction. This is to be minimized.
- variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
- variance_proportion.backward()
- x_orig_grad = x_orig.grad
- x_extra_grad = (
- x_orig.grad
- * ctx.grad_scale
- * x_grad.norm()
- / (x_orig_grad.norm() + 1.0e-20)
- )
- return x_grad + x_extra_grad.detach(), None, None, None, None
-
-
-class BiasNormFunction(torch.autograd.Function):
- # This computes:
- # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
- # return x * scales
- # (after unsqueezing the bias), but it does it in a memory-efficient way so that
- # it can just store the returned value (chances are, this will also be needed for
- # some other reason, related to the next operation, so we can save memory).
- @staticmethod
- def forward(
- ctx,
- x: Tensor,
- bias: Tensor,
- log_scale: Tensor,
- channel_dim: int,
- store_output_for_backprop: bool,
- ) -> Tensor:
- assert bias.ndim == 1
- if channel_dim < 0:
- channel_dim = channel_dim + x.ndim
- ctx.store_output_for_backprop = store_output_for_backprop
- ctx.channel_dim = channel_dim
- for _ in range(channel_dim + 1, x.ndim):
- bias = bias.unsqueeze(-1)
- scales = (
- torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
- ) * log_scale.exp()
- ans = x * scales
- ctx.save_for_backward(
- ans.detach() if store_output_for_backprop else x,
- scales.detach(),
- bias.detach(),
- log_scale.detach(),
- )
- return ans
-
- @staticmethod
- def backward(ctx, ans_grad: Tensor) -> Tensor:
- ans_or_x, scales, bias, log_scale = ctx.saved_tensors
- if ctx.store_output_for_backprop:
- x = ans_or_x / scales
- else:
- x = ans_or_x
- x = x.detach()
- x.requires_grad = True
- bias.requires_grad = True
- log_scale.requires_grad = True
- with torch.enable_grad():
- # recompute scales from x, bias and log_scale.
- scales = (
- torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
- ) * log_scale.exp()
- ans = x * scales
- ans.backward(gradient=ans_grad)
- return x.grad, bias.grad.flatten(), log_scale.grad, None, None
-
-
-class BiasNorm(torch.nn.Module):
- """
- This is intended to be a simpler, and hopefully cheaper, replacement for
- LayerNorm. The observation this is based on, is that Transformer-type
- networks, especially with pre-norm, sometimes seem to set one of the
- feature dimensions to a large constant value (e.g. 50), which "defeats"
- the LayerNorm because the output magnitude is then not strongly dependent
- on the other (useful) features. Presumably the weight and bias of the
- LayerNorm are required to allow it to do this.
-
- Instead, we give the BiasNorm a trainable bias that it can use when
- computing the scale for normalization. We also give it a (scalar)
- trainable scale on the output.
-
-
- Args:
- num_channels: the number of channels, e.g. 512.
- channel_dim: the axis/dimension corresponding to the channel,
- interpreted as an offset from the input's ndim if negative.
- This is NOT the num_channels; it should typically be one of
- {-2, -1, 0, 1, 2, 3}.
- log_scale: the initial log-scale that we multiply the output by; this
- is learnable.
- log_scale_min: FloatLike, minimum allowed value of log_scale
- log_scale_max: FloatLike, maximum allowed value of log_scale
- store_output_for_backprop: only possibly affects memory use; recommend
- to set to True if you think the output of this module is more likely
- than the input of this module to be required to be stored for the
- backprop.
- """
-
- def __init__(
- self,
- num_channels: int,
- channel_dim: int = -1, # CAUTION: see documentation.
- log_scale: float = 1.0,
- log_scale_min: float = -1.5,
- log_scale_max: float = 1.5,
- store_output_for_backprop: bool = False,
- ) -> None:
- super(BiasNorm, self).__init__()
- self.num_channels = num_channels
- self.channel_dim = channel_dim
- self.log_scale = nn.Parameter(torch.tensor(log_scale))
- self.bias = nn.Parameter(torch.zeros(num_channels))
-
- self.log_scale_min = log_scale_min
- self.log_scale_max = log_scale_max
-
- self.store_output_for_backprop = store_output_for_backprop
-
- def forward(self, x: Tensor) -> Tensor:
- assert x.shape[self.channel_dim] == self.num_channels
-
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- channel_dim = self.channel_dim
- if channel_dim < 0:
- channel_dim += x.ndim
- bias = self.bias
- for _ in range(channel_dim + 1, x.ndim):
- bias = bias.unsqueeze(-1)
- scales = (
- torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
- ) * self.log_scale.exp()
- return x * scales
-
- log_scale = limit_param_value(
- self.log_scale,
- min=float(self.log_scale_min),
- max=float(self.log_scale_max),
- training=self.training,
- )
-
- return BiasNormFunction.apply(
- x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop
- )
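-
-
-# Illustrative usage sketch of BiasNorm as a drop-in replacement for LayerNorm
-# over the last dimension (hypothetical shapes; not used by the recipe):
-def _example_bias_norm():
-    norm = BiasNorm(num_channels=512, channel_dim=-1)
-    x = torch.randn(10, 100, 512)  # (batch, time, channels)
-    y = norm(x)
-    assert y.shape == x.shape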
-
-
-def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
- """
- Behaves like a constructor of a modified version of nn.Linear
- that gives an easy way to set the default initial parameter scale.
-
- Args:
- Accepts the standard args and kwargs that nn.Linear accepts
- e.g. in_features, out_features, bias=False.
-
- initial_scale: you can override this if you want to increase
- or decrease the initial magnitude of the module's output
-           (affects the initialization of the weight and bias).
- Another option, if you want to do something like this, is
- to re-initialize the parameters.
- """
- ans = nn.Linear(*args, **kwargs)
- with torch.no_grad():
- ans.weight[:] *= initial_scale
- if ans.bias is not None:
- torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
- return ans
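-
-
-# Illustrative sketch (hypothetical sizes; not used by the recipe): ScaledLinear
-# is an ordinary nn.Linear whose initial weight is multiplied by initial_scale
-# (and whose bias is initialized in a correspondingly scaled range).
-def _example_scaled_linear():
-    proj = ScaledLinear(256, 128, bias=True, initial_scale=0.25)
-    y = proj(torch.randn(8, 256))
-    assert y.shape == (8, 128)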
-
-
-def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
- """
- Behaves like a constructor of a modified version of nn.Conv1d
- that gives an easy way to set the default initial parameter scale.
-
- Args:
-        Accepts the standard args and kwargs that nn.Conv1d accepts
-        e.g. in_channels, out_channels, kernel_size, bias=False.
-
- initial_scale: you can override this if you want to increase
- or decrease the initial magnitude of the module's output
-           (affects the initialization of the weight and bias).
- Another option, if you want to do something like this, is
- to re-initialize the parameters.
- """
- ans = nn.Conv1d(*args, **kwargs)
- with torch.no_grad():
- ans.weight[:] *= initial_scale
- if ans.bias is not None:
- torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
- return ans
-
-
-def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
- """
- Behaves like a constructor of a modified version of nn.Conv2d
- that gives an easy way to set the default initial parameter scale.
-
- Args:
-        Accepts the standard args and kwargs that nn.Conv2d accepts
-        e.g. in_channels, out_channels, kernel_size, bias=False, but:
- NO PADDING-RELATED ARGS.
-
- initial_scale: you can override this if you want to increase
- or decrease the initial magnitude of the module's output
-           (affects the initialization of the weight and bias).
- Another option, if you want to do something like this, is
- to re-initialize the parameters.
- """
- ans = nn.Conv2d(*args, **kwargs)
- with torch.no_grad():
- ans.weight[:] *= initial_scale
- if ans.bias is not None:
- torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
- return ans
-
-
-class ChunkCausalDepthwiseConv1d(torch.nn.Module):
- """
- Behaves like a depthwise 1d convolution, except that it is causal in
- a chunkwise way, as if we had a block-triangular attention mask.
- The chunk size is provided at test time (it should probably be
- kept in sync with the attention mask).
-
- This has a little more than twice the parameters of a conventional
- depthwise conv1d module: we implement it by having one
-    depthwise convolution, of half the width, that is causal (via
-    left-padding); and one depthwise convolution that is applied only
- within chunks, that we multiply by a scaling factor which depends
- on the position within the chunk.
-
-    Args:
-        channels: the number of channels (this is a depthwise convolution,
-           so in_channels == out_channels == groups == channels).
-        kernel_size: the kernel size of the convolution; must be odd.
-
-        initial_scale: you can override this if you want to increase
-           or decrease the initial magnitude of the module's output.
-        bias: if true, include a bias in the chunkwise convolution.
- """
-
- def __init__(
- self,
- channels: int,
- kernel_size: int,
- initial_scale: float = 1.0,
- bias: bool = True,
- ):
- super().__init__()
- assert kernel_size % 2 == 1
-
- half_kernel_size = (kernel_size + 1) // 2
- # will pad manually, on one side.
- self.causal_conv = nn.Conv1d(
- in_channels=channels,
- out_channels=channels,
- groups=channels,
- kernel_size=half_kernel_size,
- padding=0,
- bias=True,
- )
-
- self.chunkwise_conv = nn.Conv1d(
- in_channels=channels,
- out_channels=channels,
- groups=channels,
- kernel_size=kernel_size,
- padding=kernel_size // 2,
- bias=bias,
- )
-
- # first row is correction factors added to the scale near the left edge of the chunk,
- # second row is correction factors added to the scale near the right edge of the chunk,
- # both of these are added to a default scale of 1.0.
- self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
- self.kernel_size = kernel_size
-
- with torch.no_grad():
- self.causal_conv.weight[:] *= initial_scale
- self.chunkwise_conv.weight[:] *= initial_scale
- if bias:
- torch.nn.init.uniform_(
- self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
- )
-
- def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
- """
- Forward function. Args:
- x: a Tensor of shape (batch_size, channels, seq_len)
- chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
- """
- (batch_size, num_channels, seq_len) = x.shape
-
- # half_kernel_size = self.kernel_size + 1 // 2
- # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
- # in the causal conv. It's the amount by which we must pad on the left,
- # to make the convolution causal.
- left_pad = self.kernel_size // 2
-
- if chunk_size < 0 or chunk_size > seq_len:
- chunk_size = seq_len
- right_pad = -seq_len % chunk_size
-
- x = torch.nn.functional.pad(x, (left_pad, right_pad))
-
- x_causal = self.causal_conv(x[..., : left_pad + seq_len])
- assert x_causal.shape == (batch_size, num_channels, seq_len)
-
- x_chunk = x[..., left_pad:]
- num_chunks = x_chunk.shape[2] // chunk_size
- x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
- x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
- batch_size * num_chunks, num_channels, chunk_size
- )
- x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
-
- chunk_scale = self._get_chunk_scale(chunk_size)
-
- x_chunk = x_chunk * chunk_scale
- x_chunk = x_chunk.reshape(
- batch_size, num_chunks, num_channels, chunk_size
- ).permute(0, 2, 1, 3)
- x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
- ..., :seq_len
- ]
-
- return x_chunk + x_causal
-
- def _get_chunk_scale(self, chunk_size: int):
- """Returns tensor of shape (num_channels, chunk_size) that will be used to
- scale the output of self.chunkwise_conv."""
- left_edge = self.chunkwise_conv_scale[0]
- right_edge = self.chunkwise_conv_scale[1]
- if chunk_size < self.kernel_size:
- left_edge = left_edge[:, :chunk_size]
- right_edge = right_edge[:, -chunk_size:]
- else:
- t = chunk_size - self.kernel_size
- channels = left_edge.shape[0]
- pad = torch.zeros(
- channels, t, device=left_edge.device, dtype=left_edge.dtype
- )
- left_edge = torch.cat((left_edge, pad), dim=-1)
- right_edge = torch.cat((pad, right_edge), dim=-1)
- return 1.0 + (left_edge + right_edge)
-
- def streaming_forward(
- self,
- x: Tensor,
- cache: Tensor,
- ) -> Tuple[Tensor, Tensor]:
- """Streaming Forward function.
-
- Args:
- x: a Tensor of shape (batch_size, channels, seq_len)
- cache: cached left context of shape (batch_size, channels, left_pad)
- """
- (batch_size, num_channels, seq_len) = x.shape
-
- # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
- # in the causal conv. It's the amount by which we must pad on the left,
- # to make the convolution causal.
- left_pad = self.kernel_size // 2
-
- # Pad cache
- assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
- x = torch.cat([cache, x], dim=2)
- # Update cache
- cache = x[..., -left_pad:]
-
- x_causal = self.causal_conv(x)
- assert x_causal.shape == (batch_size, num_channels, seq_len)
-
- x_chunk = x[..., left_pad:]
- x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
-
- chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
- x_chunk = x_chunk * chunk_scale
-
- return x_chunk + x_causal, cache
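-
-
-# Illustrative usage sketch of ChunkCausalDepthwiseConv1d (hypothetical sizes;
-# not used by the recipe): the chunk size is chosen at call time and does not
-# have to divide seq_len.
-def _example_chunk_causal_conv():
-    conv = ChunkCausalDepthwiseConv1d(channels=256, kernel_size=15)
-    x = torch.randn(4, 256, 100)  # (batch, channels, seq_len)
-    y = conv(x, chunk_size=32)
-    assert y.shape == x.shape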
-
-
-class BalancerFunction(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- x: Tensor,
- min_mean: float,
- max_mean: float,
- min_rms: float,
- max_rms: float,
- grad_scale: float,
- channel_dim: int,
- ) -> Tensor:
- if channel_dim < 0:
- channel_dim += x.ndim
- ctx.channel_dim = channel_dim
- ctx.save_for_backward(x)
- ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
- return x
-
- @staticmethod
- def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
- (x,) = ctx.saved_tensors
- (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
-
- try:
- with torch.enable_grad():
- with torch.amp.autocast("cuda", enabled=False):
- x = x.to(torch.float32)
- x = x.detach()
- x.requires_grad = True
- mean_dims = [i for i in range(x.ndim) if i != channel_dim]
- uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
- mean = x.mean(dim=mean_dims, keepdim=True)
- stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
- rms = uncentered_var.clamp(min=1.0e-20).sqrt()
-
- m = mean / stddev
- # part of loss that relates to mean / stddev
- m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
-
- # put a much larger scale on the RMS-max-limit loss, so that if both it and the
- # m_loss are violated we fix the RMS loss first.
- rms_clamped = rms.clamp(min=min_rms, max=max_rms)
- r_loss = (rms_clamped / rms).log().abs()
-
- loss = m_loss + r_loss
-
- loss.backward(gradient=torch.ones_like(loss))
- loss_grad = x.grad
- loss_grad_rms = (
- (loss_grad**2)
- .mean(dim=mean_dims, keepdim=True)
- .sqrt()
- .clamp(min=1.0e-20)
- )
-
- loss_grad = loss_grad * (grad_scale / loss_grad_rms)
-
- x_grad_float = x_grad.to(torch.float32)
- # scale each element of loss_grad by the absolute value of the corresponding
- # element of x_grad, which we view as a noisy estimate of its magnitude for that
- # (frame and dimension). later we can consider factored versions.
- x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
- x_grad = x_grad_mod.to(x_grad.dtype)
- except Exception as e:
- logging.info(
- f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
- )
-
- return x_grad, None, None, None, None, None, None
-
-
-class Balancer(torch.nn.Module):
- """
- Modifies the backpropped derivatives of a function to try to encourage, for
- each channel, that it is positive at least a proportion `threshold` of the
- time. It does this by multiplying negative derivative values by up to
- (1+max_factor), and positive derivative values by up to (1-max_factor),
- interpolated from 1 at the threshold to those extremal values when none
- of the inputs are positive.
-
- Args:
- num_channels: the number of channels
- channel_dim: the dimension/axis corresponding to the channel, e.g.
- -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
- min_positive: the minimum, per channel, of the proportion of the time
- that (x > 0), below which we start to modify the derivatives.
- max_positive: the maximum, per channel, of the proportion of the time
- that (x > 0), above which we start to modify the derivatives.
- scale_gain_factor: determines the 'gain' with which we increase the
- change in gradient once the constraints on min_abs and max_abs
- are violated.
- min_abs: the minimum average-absolute-value difference from the mean
- value per channel, which we allow, before we start to modify
- the derivatives to prevent this.
- max_abs: the maximum average-absolute-value difference from the mean
- value per channel, which we allow, before we start to modify
- the derivatives to prevent this.
- prob: determines the minimum probability with which we modify the
- gradients for the {min,max}_positive and {min,max}_abs constraints,
- on each forward(). This is done randomly to prevent all layers
- from doing it at the same time.
- """
-
- def __init__(
- self,
- num_channels: int,
- channel_dim: int,
- min_positive: FloatLike = 0.05,
- max_positive: FloatLike = 0.95,
- min_abs: FloatLike = 0.2,
- max_abs: FloatLike = 100.0,
- grad_scale: FloatLike = 0.04,
- prob: Optional[FloatLike] = None,
- ):
- super().__init__()
-
- if prob is None:
- prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
- self.prob = prob
- # 5% of the time we will return and do nothing because memory usage is
- # too high.
- self.mem_cutoff = CutoffEstimator(0.05)
-
- # actually self.num_channels is no longer needed except for an assertion.
- self.num_channels = num_channels
- self.channel_dim = channel_dim
- self.min_positive = min_positive
- self.max_positive = max_positive
- self.min_abs = min_abs
- self.max_abs = max_abs
- self.grad_scale = grad_scale
-
- def forward(self, x: Tensor) -> Tensor:
- if (
- torch.jit.is_scripting()
- or not x.requires_grad
- or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
- ):
- return _no_op(x)
-
- prob = float(self.prob)
- if random.random() < prob:
- # The following inner-functions convert from the way we historically specified
- # these limitations, as limits on the absolute value and the proportion of positive
- # values, to limits on the RMS value and the (mean / stddev).
- def _abs_to_rms(x):
- # for normally distributed data, if the expected absolute value is x, the
- # expected rms value will be sqrt(pi/2) * x.
- return 1.25331413732 * x
-
- def _proportion_positive_to_mean(x):
- def _atanh(x):
- eps = 1.0e-10
- # eps is to prevent crashes if x is exactly 0 or 1.
- # we'll just end up returning a fairly large value.
- return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
-
- def _approx_inverse_erf(x):
- # 1 / (sqrt(pi) * ln(2)),
- # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions
- # this approximation is extremely crude and gets progressively worse for
- # x very close to -1 or +1, but we mostly care about the "middle" region
- # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
- # and math.erf(0.0407316414078772) = 0.045935330944660666,
- # which is pretty close to 0.05.
- return 0.8139535143 * _atanh(x)
-
- # first convert x from the range 0..1 to the range -1..1 which the error
- # function returns
- x = -1 + (2 * x)
- return _approx_inverse_erf(x)
-
- min_mean = _proportion_positive_to_mean(float(self.min_positive))
- max_mean = _proportion_positive_to_mean(float(self.max_positive))
- min_rms = _abs_to_rms(float(self.min_abs))
- max_rms = _abs_to_rms(float(self.max_abs))
- grad_scale = float(self.grad_scale)
-
- assert x.shape[self.channel_dim] == self.num_channels
-
- return BalancerFunction.apply(
- x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
- )
- else:
- return _no_op(x)
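-
-
-# Illustrative sketch of Balancer (hypothetical sizes; not used by the recipe):
-# the forward pass is an identity; the constraints only nudge gradients in
-# backprop.
-def _example_balancer():
-    balancer = Balancer(
-        num_channels=384, channel_dim=-1, min_abs=0.2, max_abs=100.0, prob=1.0
-    )
-    x = torch.randn(8, 50, 384, requires_grad=True)
-    y = balancer(x)  # same values as x
-    y.sum().backward()  # gradient modification happens here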
-
-
-def penalize_abs_values_gt(
- x: Tensor, limit: float, penalty: float, name: str = None
-) -> Tensor:
- """
- Returns x unmodified, but in backprop will put a penalty for the excess of
- the absolute values of elements of x over the limit "limit". E.g. if
- limit == 10.0, then if x has any values over 10 it will get a penalty.
-
- Caution: the value of this penalty will be affected by grad scaling used
-    in automatic mixed precision training.  For the purposes we use this for,
-    that shouldn't really matter, and may even be helpful; we just use this
-    to disallow really implausible score values from being given to the softmax.
-
- The name is for randomly printed debug info.
- """
- x_sign = x.sign()
- over_limit = (x.abs() - limit) > 0
- # The following is a memory efficient way to penalize the absolute values of
- # x that's over the limit. (The memory efficiency comes when you think
- # about which items torch needs to cache for the autograd, and which ones it
- # can throw away). The numerical value of aux_loss as computed here will
- # actually be larger than it should be, by limit * over_limit.sum(), but it
- # has the same derivative as the real aux_loss which is penalty * (x.abs() -
- # limit).relu().
- aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
-    # note: we don't do sum() here on aux_loss, but it's as if we had done
- # sum() due to how with_loss() works.
- x = with_loss(x, aux_loss, name)
- # you must use x for something, or this will be ineffective.
- return x
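-
-
-# Illustrative sketch of penalize_abs_values_gt (hypothetical values; not used
-# by the recipe): the returned tensor must be used downstream, otherwise the
-# penalty never reaches the backward pass.
-def _example_penalize_abs_values_gt():
-    scores = (30.0 * torch.randn(4, 10)).requires_grad_(True)
-    limited = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04)
-    limited.softmax(dim=-1).sum().backward()
-    return scores.grad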
-
-
-def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
- if x.ndim == 2:
- return x.diag()
- else:
- (batch, dim, dim) = x.shape
- x = x.reshape(batch, dim * dim)
- x = x[:, :: dim + 1]
- assert x.shape == (batch, dim)
- return x
-
-
-def _whitening_metric(x: Tensor, num_groups: int):
- """
-    Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues
-    of the centered feature covariance are the same within each group's covariance matrix
- and also between groups.
- Args:
- x: a Tensor of shape (*, num_channels)
- num_groups: the number of groups of channels, a number >=1 that divides num_channels
- Returns:
- Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
- greater than 1.0 otherwise.
- """
- assert x.dtype != torch.float16
- x = x.reshape(-1, x.shape[-1])
- (num_frames, num_channels) = x.shape
- assert num_channels % num_groups == 0
- channels_per_group = num_channels // num_groups
- x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
- # x now has shape (num_groups, num_frames, channels_per_group)
- # subtract the mean so we use the centered, not uncentered, covariance.
- # My experience has been that when we "mess with the gradients" like this,
- # it's better not do anything that tries to move the mean around, because
- # that can easily cause instability.
- x = x - x.mean(dim=1, keepdim=True)
- # x_covar: (num_groups, channels_per_group, channels_per_group)
- x_covar = torch.matmul(x.transpose(1, 2), x)
- x_covar_mean_diag = _diag(x_covar).mean()
- # the following expression is what we'd get if we took the matrix product
- # of each covariance and measured the mean of its trace, i.e.
- # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
- x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
- # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
- metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
- return metric
-
-
-class WhiteningPenaltyFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
- ctx.save_for_backward(x)
- ctx.module = module
- return x
-
- @staticmethod
- def backward(ctx, x_grad: Tensor):
- (x_orig,) = ctx.saved_tensors
- w = ctx.module
-
- try:
- with torch.enable_grad():
- with torch.amp.autocast("cuda", enabled=False):
- x_detached = x_orig.to(torch.float32).detach()
- x_detached.requires_grad = True
-
- metric = _whitening_metric(x_detached, w.num_groups)
-
- if random.random() < 0.005 or __name__ == "__main__":
- logging.debug(
- f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, "
- f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}"
- )
-
- if metric < float(w.whitening_limit):
- w.prob = w.min_prob
- return x_grad, None
- else:
- w.prob = w.max_prob
- metric.backward()
- penalty_grad = x_detached.grad
- scale = w.grad_scale * (
- x_grad.to(torch.float32).norm()
- / (penalty_grad.norm() + 1.0e-20)
- )
- penalty_grad = penalty_grad * scale
- return x_grad + penalty_grad.to(x_grad.dtype), None
- except Exception as e:
- logging.info(
- f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue."
- )
- return x_grad, None
-
-
-class Whiten(nn.Module):
- def __init__(
- self,
- num_groups: int,
- whitening_limit: FloatLike,
- prob: Union[float, Tuple[float, float]],
- grad_scale: FloatLike,
- ):
- """
- Args:
- num_groups: the number of groups to divide the channel dim into before
- whitening. We will attempt to make the feature covariance
- within each group, after mean subtraction, as "white" as possible,
- while having the same trace across all groups.
- whitening_limit: a value greater than 1.0, that dictates how much
- freedom we have to violate the constraints. 1.0 would mean perfectly
- white, with exactly the same trace across groups; larger values
- give more freedom. E.g. 2.0.
- prob: the probability with which we apply the gradient modification
- (also affects the grad scale). May be supplied as a float,
- or as a pair (min_prob, max_prob)
-
- grad_scale: determines the scale on the gradient term from this object,
- relative to the rest of the gradient on the attention weights.
- E.g. 0.02 (you may want to use smaller values than this if prob is large)
- """
- super(Whiten, self).__init__()
- assert num_groups >= 1
- assert float(whitening_limit) >= 1
- assert grad_scale >= 0
- self.num_groups = num_groups
- self.whitening_limit = whitening_limit
- self.grad_scale = grad_scale
-
- if isinstance(prob, float):
- prob = (prob, prob)
- (self.min_prob, self.max_prob) = prob
- assert 0 < self.min_prob <= self.max_prob <= 1
- self.prob = self.max_prob
- self.name = None # will be set in training loop
-
- def forward(self, x: Tensor) -> Tensor:
- """
- In the forward pass, this function just returns the input unmodified.
- In the backward pass, it will modify the gradients to ensure that the
- distribution in each group has close to (lambda times I) as the covariance
- after mean subtraction, with the same lambda across groups.
- For whitening_limit > 1, there will be more freedom to violate this
- constraint.
-
- Args:
- x: the input of shape (*, num_channels)
-
- Returns:
- x, unmodified. You should make sure
- you use the returned value, or the graph will be freed
- and nothing will happen in backprop.
- """
- grad_scale = float(self.grad_scale)
- if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
- return _no_op(x)
- else:
- return WhiteningPenaltyFunction.apply(x, self)
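-
-
-# Illustrative sketch of Whiten (hypothetical sizes; not used by the recipe):
-# an identity in the forward pass; a gradient penalty is added in backprop when
-# the per-group covariance is far from "white".
-def _example_whiten():
-    whiten = Whiten(num_groups=1, whitening_limit=5.0, prob=1.0, grad_scale=0.02)
-    x = torch.randn(200, 256, requires_grad=True)
-    y = whiten(x)
-    y.sum().backward()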
-
-
-class WithLoss(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x: Tensor, y: Tensor, name: str):
- ctx.y_shape = y.shape
- if random.random() < 0.002 and name is not None:
- loss_sum = y.sum().item()
- logging.debug(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
- return x
-
- @staticmethod
- def backward(ctx, ans_grad: Tensor):
- return (
- ans_grad,
- torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
- None,
- )
-
-
-def with_loss(x, y, name):
- # returns x but adds y.sum() to the loss function.
- return WithLoss.apply(x, y, name)
-
-
-class ScaleGradFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x: Tensor, alpha: float) -> Tensor:
- ctx.alpha = alpha
- return x
-
- @staticmethod
- def backward(ctx, grad: Tensor):
- return grad * ctx.alpha, None
-
-
-def scale_grad(x: Tensor, alpha: float):
- return ScaleGradFunction.apply(x, alpha)
-
-
-class ScaleGrad(nn.Module):
- def __init__(self, alpha: float):
- super().__init__()
- self.alpha = alpha
-
- def forward(self, x: Tensor) -> Tensor:
- if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
- return x
- return scale_grad(x, self.alpha)
-
-
-class LimitParamValue(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x: Tensor, min: float, max: float):
- ctx.save_for_backward(x)
- assert max >= min
- ctx.min = min
- ctx.max = max
- return x
-
- @staticmethod
- def backward(ctx, x_grad: Tensor):
- (x,) = ctx.saved_tensors
- # where x < ctx.min, ensure all grads are negative (this will tend to make
- # x more positive).
- x_grad = x_grad * torch.where(
- torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
- )
- # where x > ctx.max, ensure all grads are positive (this will tend to make
- # x more negative).
- x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
- return x_grad, None, None
-
-
-def limit_param_value(
- x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
-):
- # You apply this to (typically) an nn.Parameter during training to ensure that its
-    # elements (mostly) stay within a supplied range. This is done by modifying the
- # gradients in backprop.
- # It's not necessary to do this on every batch: do it only some of the time,
- # to save a little time.
- if training and random.random() < prob:
- return LimitParamValue.apply(x, min, max)
- else:
- return x
-
-
-def _no_op(x: Tensor) -> Tensor:
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- return x
- else:
- # a no-op function that will have a node in the autograd graph,
- # to avoid certain bugs relating to backward hooks
- return x.chunk(1, dim=-1)[0]
-
-
-class Identity(torch.nn.Module):
- def __init__(self):
- super(Identity, self).__init__()
-
- def forward(self, x):
- return _no_op(x)
-
-
-class DoubleSwishFunction(torch.autograd.Function):
- """
- double_swish(x) = x * torch.sigmoid(x-1)
-
- This is a definition, originally motivated by its close numerical
- similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
-
- Memory-efficient derivative computation:
- double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
- double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
- Now, s'(x) = s(x) * (1-s(x)).
- double_swish'(x) = x * s'(x) + s(x).
- = x * s(x) * (1-s(x)) + s(x).
- = double_swish(x) * (1-s(x)) + s(x)
- ... so we just need to remember s(x) but not x itself.
- """
-
- @staticmethod
- def forward(ctx, x: Tensor) -> Tensor:
- requires_grad = x.requires_grad
- if x.dtype == torch.float16:
- x = x.to(torch.float32)
-
- s = torch.sigmoid(x - 1.0)
- y = x * s
-
- if requires_grad:
- deriv = y * (1 - s) + s
-
- # notes on derivative of x * sigmoid(x - 1):
- # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
-            # min \simeq -0.043638. Take floor as -0.044 so it's a lower bound
- # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
- # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
- # floors), should be expectation-preserving.
- floor = -0.044
- ceil = 1.2
- d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
- deriv
- )
- if __name__ == "__main__":
- # for self-testing only.
- assert d_scaled.min() >= 0.0
- assert d_scaled.max() < 256.0
- d_int = d_scaled.to(torch.uint8)
- ctx.save_for_backward(d_int)
- if x.dtype == torch.float16 or torch.is_autocast_enabled():
- y = y.to(torch.float16)
- return y
-
- @staticmethod
- def backward(ctx, y_grad: Tensor) -> Tensor:
- (d,) = ctx.saved_tensors
- # the same constants as used in forward pass.
- floor = -0.043637
- ceil = 1.2
-
- d = d * ((ceil - floor) / 255.0) + floor
- return y_grad * d
-
-
-class DoubleSwish(torch.nn.Module):
- def __init__(self):
- super().__init__()
-
- def forward(self, x: Tensor) -> Tensor:
-        """Return the double-swish activation, an approximation to Swish(Swish(x))
-        that we compute here as x * sigmoid(x-1).
- """
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- return x * torch.sigmoid(x - 1.0)
- return DoubleSwishFunction.apply(x)
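-
-
-# Illustrative numerical check (hypothetical sizes; not used by the recipe):
-# DoubleSwish matches x * sigmoid(x - 1); the custom autograd only changes how
-# the derivative is stored (quantized to uint8) for the backward pass.
-def _example_double_swish():
-    x = torch.randn(16, 32)
-    y = DoubleSwish()(x)
-    assert torch.allclose(y, x * torch.sigmoid(x - 1.0), atol=1.0e-05)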
-
-
-# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
-class Dropout2(nn.Module):
- def __init__(self, p: FloatLike):
- super().__init__()
- self.p = p
-
- def forward(self, x: Tensor) -> Tensor:
- return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
-
-
-class MulForDropout3(torch.autograd.Function):
- # returns (x * y * alpha) where alpha is a float and y doesn't require
- # grad and is zero-or-one.
- @staticmethod
- @custom_fwd
- def forward(ctx, x, y, alpha):
- assert not y.requires_grad
- ans = x * y * alpha
- ctx.save_for_backward(ans)
- ctx.alpha = alpha
- return ans
-
- @staticmethod
- @custom_bwd
- def backward(ctx, ans_grad):
- (ans,) = ctx.saved_tensors
- x_grad = ctx.alpha * ans_grad * (ans != 0)
- return x_grad, None, None
-
-
-# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
-# and it lets you choose one dimension to share the dropout mask over
-class Dropout3(nn.Module):
- def __init__(self, p: FloatLike, shared_dim: int):
- super().__init__()
- self.p = p
- self.shared_dim = shared_dim
-
- def forward(self, x: Tensor) -> Tensor:
- p = float(self.p)
- if not self.training or p == 0:
- return _no_op(x)
- scale = 1.0 / (1 - p)
- rand_shape = list(x.shape)
- rand_shape[self.shared_dim] = 1
- mask = torch.rand(*rand_shape, device=x.device) > p
- ans = MulForDropout3.apply(x, mask, scale)
- return ans
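-
-
-# Illustrative sketch of Dropout3 (hypothetical sizes; not used by the recipe):
-# the keep/drop decision is shared across `shared_dim` (here the time
-# dimension), so a dropped channel is dropped for all time steps.
-def _example_dropout3():
-    drop = Dropout3(p=0.1, shared_dim=1)
-    x = torch.randn(8, 100, 256)  # (batch, time, channels)
-    y = drop(x)  # in eval mode (drop.eval()) this is a no-op
-    assert y.shape == x.shape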
-
-
-class SwooshLFunction(torch.autograd.Function):
- """
- swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
- """
-
- @staticmethod
- def forward(ctx, x: Tensor) -> Tensor:
- requires_grad = x.requires_grad
- if x.dtype == torch.float16:
- x = x.to(torch.float32)
-
- zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-
- coeff = -0.08
-
- with torch.amp.autocast("cuda", enabled=False):
- with torch.enable_grad():
- x = x.detach()
- x.requires_grad = True
- y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
-
- if not requires_grad:
- return y
-
- y.backward(gradient=torch.ones_like(y))
-
- grad = x.grad
- floor = coeff
- ceil = 1.0 + coeff + 0.005
-
- d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
- grad
- )
- if __name__ == "__main__":
- # for self-testing only.
- assert d_scaled.min() >= 0.0
- assert d_scaled.max() < 256.0
-
- d_int = d_scaled.to(torch.uint8)
- ctx.save_for_backward(d_int)
- if x.dtype == torch.float16 or torch.is_autocast_enabled():
- y = y.to(torch.float16)
- return y
-
- @staticmethod
- def backward(ctx, y_grad: Tensor) -> Tensor:
- (d,) = ctx.saved_tensors
- # the same constants as used in forward pass.
-
- coeff = -0.08
- floor = coeff
- ceil = 1.0 + coeff + 0.005
- d = d * ((ceil - floor) / 255.0) + floor
- return y_grad * d
-
-
-class SwooshL(torch.nn.Module):
- def forward(self, x: Tensor) -> Tensor:
- """Return Swoosh-L activation."""
- if (
- torch.jit.is_scripting()
- or torch.jit.is_tracing()
- or "k2" not in sys.modules
- ):
- zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
- return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
- if not x.requires_grad:
- return k2.swoosh_l_forward(x)
- else:
- return k2.swoosh_l(x)
- # return SwooshLFunction.apply(x)
-
-
-class SwooshLOnnx(torch.nn.Module):
- def forward(self, x: Tensor) -> Tensor:
- """Return Swoosh-L activation."""
- zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
- return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
-
-
-class SwooshRFunction(torch.autograd.Function):
- """
- swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
-
- derivatives are between -0.08 and 0.92.
- """
-
- @staticmethod
- def forward(ctx, x: Tensor) -> Tensor:
- requires_grad = x.requires_grad
-
- if x.dtype == torch.float16:
- x = x.to(torch.float32)
-
- zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-
- with torch.amp.autocast("cuda", enabled=False):
- with torch.enable_grad():
- x = x.detach()
- x.requires_grad = True
- y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
-
- if not requires_grad:
- return y
- y.backward(gradient=torch.ones_like(y))
-
- grad = x.grad
- floor = -0.08
- ceil = 0.925
-
- d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
- grad
- )
- if __name__ == "__main__":
- # for self-testing only.
- assert d_scaled.min() >= 0.0
- assert d_scaled.max() < 256.0
-
- d_int = d_scaled.to(torch.uint8)
- ctx.save_for_backward(d_int)
- if x.dtype == torch.float16 or torch.is_autocast_enabled():
- y = y.to(torch.float16)
- return y
-
- @staticmethod
- def backward(ctx, y_grad: Tensor) -> Tensor:
- (d,) = ctx.saved_tensors
- # the same constants as used in forward pass.
- floor = -0.08
- ceil = 0.925
- d = d * ((ceil - floor) / 255.0) + floor
- return y_grad * d
-
-
-class SwooshR(torch.nn.Module):
- def forward(self, x: Tensor) -> Tensor:
- """Return Swoosh-R activation."""
- if (
- torch.jit.is_scripting()
- or torch.jit.is_tracing()
- or "k2" not in sys.modules
- ):
- zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
- return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
- if not x.requires_grad:
- return k2.swoosh_r_forward(x)
- else:
- return k2.swoosh_r(x)
- # return SwooshRFunction.apply(x)
-
-
-class SwooshROnnx(torch.nn.Module):
- def forward(self, x: Tensor) -> Tensor:
- """Return Swoosh-R activation."""
- zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
- return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
-
-
-# simple version of SwooshL that does not redefine the backprop, used in
-# ActivationDropoutAndLinearFunction.
-def SwooshLForward(x: Tensor):
- x_offset = x - 4.0
- log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
- log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
- return log_sum - 0.08 * x - 0.035
-
-
-# simple version of SwooshR that does not redefine the backprop, used in
-# ActivationDropoutAndLinearFunction.
-def SwooshRForward(x: Tensor):
- x_offset = x - 1.0
- log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
- log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
- return log_sum - 0.08 * x - 0.313261687
-
-
-class ActivationDropoutAndLinearFunction(torch.autograd.Function):
- @staticmethod
- @custom_fwd
- def forward(
- ctx,
- x: Tensor,
- weight: Tensor,
- bias: Optional[Tensor],
- activation: str,
- dropout_p: float,
- dropout_shared_dim: Optional[int],
- ):
- if dropout_p != 0.0:
- dropout_shape = list(x.shape)
- if dropout_shared_dim is not None:
- dropout_shape[dropout_shared_dim] = 1
- # else it won't be very memory efficient.
- dropout_mask = (1.0 / (1.0 - dropout_p)) * (
- torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
- )
- else:
- dropout_mask = None
-
- ctx.save_for_backward(x, weight, bias, dropout_mask)
-
- ctx.activation = activation
-
- forward_activation_dict = {
- "SwooshL": k2.swoosh_l_forward,
- "SwooshR": k2.swoosh_r_forward,
- }
- # it will raise a KeyError if this fails. This will be an error. We let it
- # propagate to the user.
- activation_func = forward_activation_dict[activation]
- x = activation_func(x)
- if dropout_mask is not None:
- x = x * dropout_mask
- x = torch.nn.functional.linear(x, weight, bias)
- return x
-
- @staticmethod
- @custom_bwd
- def backward(ctx, ans_grad: Tensor):
- saved = ctx.saved_tensors
- (x, weight, bias, dropout_mask) = saved
-
- forward_and_deriv_activation_dict = {
- "SwooshL": k2.swoosh_l_forward_and_deriv,
- "SwooshR": k2.swoosh_r_forward_and_deriv,
- }
-        # the following line raises a KeyError if the activation is unrecognized.
- # This will be an error. We let it propagate to the user.
- func = forward_and_deriv_activation_dict[ctx.activation]
-
- y, func_deriv = func(x)
- if dropout_mask is not None:
- y = y * dropout_mask
- # now compute derivative of y w.r.t. weight and bias..
- # y: (..., in_channels), ans_grad: (..., out_channels),
- (out_channels, in_channels) = weight.shape
-
- in_channels = y.shape[-1]
- g = ans_grad.reshape(-1, out_channels)
- weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
- y_deriv = torch.matmul(ans_grad, weight)
- bias_deriv = None if bias is None else g.sum(dim=0)
- x_deriv = y_deriv * func_deriv
- if dropout_mask is not None:
- # order versus func_deriv does not matter
- x_deriv = x_deriv * dropout_mask
-
- return x_deriv, weight_deriv, bias_deriv, None, None, None
-
-
-class ActivationDropoutAndLinear(torch.nn.Module):
- """
- This merges an activation function followed by dropout and then a nn.Linear module;
- it does so in a memory efficient way so that it only stores the input to the whole
- module. If activation == SwooshL and dropout_shared_dim != None, this will be
- equivalent to:
- nn.Sequential(SwooshL(),
- Dropout3(dropout_p, shared_dim=dropout_shared_dim),
- ScaledLinear(in_channels, out_channels, bias=bias,
- initial_scale=initial_scale))
- If dropout_shared_dim is None, the dropout would be equivalent to
- Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout
- mask is smaller.
-
- Args:
- in_channels: number of input channels, e.g. 256
- out_channels: number of output channels, e.g. 256
- bias: if true, have a bias
- activation: the activation function, for now just support SwooshL.
- dropout_p: the dropout probability or schedule (happens after nonlinearity).
- dropout_shared_dim: the dimension, if any, across which the dropout mask is
- shared (e.g. the time dimension). If None, this may be less memory
- efficient if there are modules before this one that cache the input
- for their backprop (e.g. Balancer or Whiten).
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- bias: bool = True,
- activation: str = "SwooshL",
- dropout_p: FloatLike = 0.0,
- dropout_shared_dim: Optional[int] = -1,
- initial_scale: float = 1.0,
- ):
- super().__init__()
- # create a temporary module of nn.Linear that we'll steal the
- # weights and bias from
- l = ScaledLinear(
- in_channels, out_channels, bias=bias, initial_scale=initial_scale
- )
-
- self.weight = l.weight
- # register_parameter properly handles making it a parameter when l.bias
- # is None. I think there is some reason for doing it this way rather
- # than just setting it to None but I don't know what it is, maybe
- # something to do with exporting the module..
- self.register_parameter("bias", l.bias)
-
- self.activation = activation
- self.dropout_p = dropout_p
- self.dropout_shared_dim = dropout_shared_dim
-
- def forward(self, x: Tensor):
- if (
- torch.jit.is_scripting()
- or torch.jit.is_tracing()
- or "k2" not in sys.modules
- ):
- if self.activation == "SwooshL":
- x = SwooshLForward(x)
- elif self.activation == "SwooshR":
- x = SwooshRForward(x)
- else:
- assert False, self.activation
- return torch.nn.functional.linear(x, self.weight, self.bias)
-
- return ActivationDropoutAndLinearFunction.apply(
- x,
- self.weight,
- self.bias,
- self.activation,
- float(self.dropout_p),
- self.dropout_shared_dim,
- )
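-
-
-# Illustrative usage sketch of ActivationDropoutAndLinear (hypothetical sizes;
-# not used by the recipe): it is used like a single module in place of
-# SwooshL() -> Dropout3 -> ScaledLinear.
-def _example_activation_dropout_and_linear():
-    layer = ActivationDropoutAndLinear(
-        in_channels=256, out_channels=512, activation="SwooshL", dropout_p=0.1
-    )
-    x = torch.randn(8, 100, 256)
-    y = layer(x)
-    assert y.shape == (8, 100, 512)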
-
-
-def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
- if num_channels <= x.shape[-1]:
- return x[..., :num_channels]
- else:
- shape = list(x.shape)
- shape[-1] = num_channels - shape[-1]
- zeros = torch.zeros(shape, dtype=x.dtype, device=x.device)
- return torch.cat((x, zeros), dim=-1)
-
-
-def _test_whiten():
- for proportion in [0.1, 0.5, 10.0]:
- logging.info(f"_test_whiten(): proportion = {proportion}")
- x = torch.randn(100, 128)
- direction = torch.randn(128)
- coeffs = torch.randn(100, 1)
- x += proportion * direction * coeffs
-
- x.requires_grad = True
-
- m = Whiten(
- 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
- ) # grad_scale
-
- for _ in range(4):
- y = m(x)
-
- y_grad = torch.randn_like(x)
- y.backward(gradient=y_grad)
-
- if proportion < 0.2:
- assert torch.allclose(x.grad, y_grad)
- elif proportion > 1.0:
- assert not torch.allclose(x.grad, y_grad)
-
-
-def _test_balancer_sign():
- probs = torch.arange(0, 1, 0.01)
- N = 1000
- x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
- x = x.detach()
- x.requires_grad = True
- m = Balancer(
- probs.numel(),
- channel_dim=0,
- min_positive=0.05,
- max_positive=0.95,
- min_abs=0.0,
- prob=1.0,
- )
-
- y_grad = torch.sign(torch.randn(probs.numel(), N))
-
- y = m(x)
- y.backward(gradient=y_grad)
- print("_test_balancer_sign: x = ", x)
- print("_test_balancer_sign: y grad = ", y_grad)
- print("_test_balancer_sign: x grad = ", x.grad)
-
-
-def _test_balancer_magnitude():
- magnitudes = torch.arange(0, 1, 0.01)
- N = 1000
- x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
- x = x.detach()
- x.requires_grad = True
- m = Balancer(
- magnitudes.numel(),
- channel_dim=0,
- min_positive=0.0,
- max_positive=1.0,
- min_abs=0.2,
- max_abs=0.7,
- prob=1.0,
- )
-
- y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
-
- y = m(x)
- y.backward(gradient=y_grad)
- print("_test_balancer_magnitude: x = ", x)
- print("_test_balancer_magnitude: y grad = ", y_grad)
- print("_test_balancer_magnitude: x grad = ", x.grad)
-
-
-def _test_double_swish_deriv():
- x = torch.randn(10, 12, dtype=torch.double) * 3.0
- x.requires_grad = True
- m = DoubleSwish()
-
- tol = (1.2 - (-0.043637)) / 255.0
- torch.autograd.gradcheck(m, x, atol=tol)
-
- # for self-test.
- x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
- x.requires_grad = True
- y = m(x)
-
-
-def _test_swooshl_deriv():
- x = torch.randn(10, 12, dtype=torch.double) * 3.0
- x.requires_grad = True
- m = SwooshL()
-
- tol = 1.0 / 255.0
- torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
-
- # for self-test.
- x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
- x.requires_grad = True
- y = m(x)
-
-
-def _test_swooshr_deriv():
- x = torch.randn(10, 12, dtype=torch.double) * 3.0
- x.requires_grad = True
- m = SwooshR()
-
- tol = 1.0 / 255.0
- torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
-
- # for self-test.
- x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
- x.requires_grad = True
- y = m(x)
-
-
-def _test_softmax():
- a = torch.randn(2, 10, dtype=torch.float64)
- b = a.clone()
- a.requires_grad = True
- b.requires_grad = True
- a.softmax(dim=1)[:, 0].sum().backward()
- print("a grad = ", a.grad)
- softmax(b, dim=1)[:, 0].sum().backward()
- print("b grad = ", b.grad)
- assert torch.allclose(a.grad, b.grad)
-
-
-def _test_piecewise_linear():
- p = PiecewiseLinear((0, 10.0))
- for x in [-100, 0, 100]:
- assert p(x) == 10.0
- p = PiecewiseLinear((0, 10.0), (1, 0.0))
- for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
- print("x, y = ", x, y)
- assert p(x) == y, (x, p(x), y)
-
- q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
- x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
- pq = p.max(q)
- for x in x_vals:
- y1 = max(p(x), q(x))
- y2 = pq(x)
- assert abs(y1 - y2) < 0.001
- pq = p.min(q)
- for x in x_vals:
- y1 = min(p(x), q(x))
- y2 = pq(x)
- assert abs(y1 - y2) < 0.001
- pq = p + q
- for x in x_vals:
- y1 = p(x) + q(x)
- y2 = pq(x)
- assert abs(y1 - y2) < 0.001
-
-
-def _test_activation_dropout_and_linear():
- in_channels = 20
- out_channels = 30
-
- for bias in [True, False]:
-        # We don't test dropout_p != 0.0 because the two forward passes would give
-        # different answers: the k2 implementations of swoosh_l and swoosh_r used
-        # inside SwooshL() and SwooshR() call randn() internally, which perturbs
-        # the random state.
- for dropout_p in [0.0]:
- for activation in ["SwooshL", "SwooshR"]:
- m1 = nn.Sequential(
- SwooshL() if activation == "SwooshL" else SwooshR(),
- Dropout3(p=dropout_p, shared_dim=-1),
- ScaledLinear(
- in_channels, out_channels, bias=bias, initial_scale=0.5
- ),
- )
- m2 = ActivationDropoutAndLinear(
- in_channels,
- out_channels,
- bias=bias,
- initial_scale=0.5,
- activation=activation,
- dropout_p=dropout_p,
- )
- with torch.no_grad():
- m2.weight[:] = m1[2].weight
- if bias:
- m2.bias[:] = m1[2].bias
- # make sure forward gives same result.
- x1 = torch.randn(10, in_channels)
- x1.requires_grad = True
-
- # TEMP.
- assert torch.allclose(
- SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
- )
-
- x2 = x1.clone().detach()
- x2.requires_grad = True
- seed = 10
- torch.manual_seed(seed)
- y1 = m1(x1)
- y_grad = torch.randn_like(y1)
- y1.backward(gradient=y_grad)
- torch.manual_seed(seed)
- y2 = m2(x2)
- y2.backward(gradient=y_grad)
-
- print(
- f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
- )
- print("y1 = ", y1)
- print("y2 = ", y2)
- assert torch.allclose(y1, y2, atol=0.02)
- assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
- if bias:
- assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
- print("x1.grad = ", x1.grad)
- print("x2.grad = ", x2.grad)
-
- def isclose(a, b):
- # return true if cosine similarity is > 0.9.
- return (a * b).sum() > 0.9 * (
- (a**2).sum() * (b**2).sum()
- ).sqrt()
-
-                    # The SwooshL() implementation has a noisy gradient due to its
-                    # 1-byte internal storage.
- assert isclose(x1.grad, x2.grad)
-
-
-if __name__ == "__main__":
- logging.getLogger().setLevel(logging.DEBUG)
- torch.set_num_threads(1)
- torch.set_num_interop_threads(1)
- _test_piecewise_linear()
- _test_softmax()
- _test_whiten()
- _test_balancer_sign()
- _test_balancer_magnitude()
- _test_double_swish_deriv()
- _test_swooshr_deriv()
- _test_swooshl_deriv()
- _test_activation_dropout_and_linear()
diff --git a/egs/zipvoice/zipvoice/solver.py b/egs/zipvoice/zipvoice/solver.py
deleted file mode 100644
index a1e316ec8..000000000
--- a/egs/zipvoice/zipvoice/solver.py
+++ /dev/null
@@ -1,277 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Optional, Union
-
-import torch
-
-
-class DiffusionModel(torch.nn.Module):
- """A wrapper of diffusion models for inference.
- Args:
- model: The diffusion model.
- distill: Whether it is a distillation model.
- """
-
- def __init__(
- self,
- model: torch.nn.Module,
- distill: bool = False,
- func_name: str = "forward_fm_decoder",
- ):
- super().__init__()
- self.model = model
- self.distill = distill
- self.func_name = func_name
- self.model_func = getattr(self.model, func_name)
-
- def forward(
- self,
- t: torch.Tensor,
- x: torch.Tensor,
- text_condition: torch.Tensor,
- speech_condition: torch.Tensor,
- padding_mask: Optional[torch.Tensor] = None,
- guidance_scale: Union[float, torch.Tensor] = 0.0,
- **kwargs
- ) -> torch.Tensor:
- """
-        Forward function that handles classifier-free guidance.
-        Args:
-            t: The current timestep, a tensor of shape (batch, 1, 1) or a scalar tensor.
-            x: The initial value, with the shape (batch, seq_len, emb_dim).
-            text_condition: The text condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
-            speech_condition: The speech condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
-            padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
-            guidance_scale: The scale of classifier-free guidance, a float or a tensor of shape (batch, 1, 1).
-        Returns:
-            The prediction with the shape (batch, seq_len, emb_dim).
- """
- if not torch.is_tensor(guidance_scale):
- guidance_scale = torch.tensor(
- guidance_scale, dtype=t.dtype, device=t.device
- )
- if self.distill:
- return self.model_func(
- t=t,
- xt=x,
- text_condition=text_condition,
- speech_condition=speech_condition,
- padding_mask=padding_mask,
- guidance_scale=guidance_scale,
- **kwargs
- )
-
- if (guidance_scale == 0.0).all():
- return self.model_func(
- t=t,
- xt=x,
- text_condition=text_condition,
- speech_condition=speech_condition,
- padding_mask=padding_mask,
- **kwargs
- )
- else:
- if t.dim() != 0:
- t = torch.cat([t] * 2, dim=0)
-
- x = torch.cat([x] * 2, dim=0)
- padding_mask = torch.cat([padding_mask] * 2, dim=0)
-
- text_condition = torch.cat(
- [torch.zeros_like(text_condition), text_condition], dim=0
- )
-
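-            # For t > 0.5 the unconditional branch also drops the speech condition;
-            # for smaller t the speech condition is kept in both branches and the
-            # guidance scale is doubled instead.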
- if t.dim() == 0:
- if t > 0.5:
- speech_condition = torch.cat(
- [torch.zeros_like(speech_condition), speech_condition], dim=0
- )
- else:
- guidance_scale = guidance_scale * 2
- speech_condition = torch.cat(
- [speech_condition, speech_condition], dim=0
- )
- else:
- assert t.dim() > 0, t
- larger_t_index = (t > 0.5).squeeze(1).squeeze(1)
- zero_speech_condition = torch.cat(
- [torch.zeros_like(speech_condition), speech_condition], dim=0
- )
- speech_condition = torch.cat(
- [speech_condition, speech_condition], dim=0
- )
- speech_condition[larger_t_index] = zero_speech_condition[larger_t_index]
- guidance_scale[~larger_t_index[: larger_t_index.size(0) // 2]] *= 2
-
- data_uncond, data_cond = self.model_func(
- t=t,
- xt=x,
- text_condition=text_condition,
- speech_condition=speech_condition,
- padding_mask=padding_mask,
- **kwargs
- ).chunk(2, dim=0)
-
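-            # Classifier-free guidance: extrapolate from the unconditional prediction
-            # towards the conditional one, prediction = (1 + w) * cond - w * uncond.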
- res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
- return res
-
-
-class EulerSolver:
- def __init__(
- self,
- model: torch.nn.Module,
- distill: bool = False,
- func_name: str = "forward_fm_decoder",
- ):
-        """Construct an Euler solver.
-        Args:
-            model: The diffusion model.
-            distill: Whether it is a distillation model.
- """
-
- self.model = DiffusionModel(model, distill=distill, func_name=func_name)
-
- def sample(
- self,
- x: torch.Tensor,
- text_condition: torch.Tensor,
- speech_condition: torch.Tensor,
- padding_mask: torch.Tensor,
- num_step: int = 10,
- guidance_scale: Union[float, torch.Tensor] = 0.0,
- t_start: Union[float, torch.Tensor] = 0.0,
- t_end: Union[float, torch.Tensor] = 1.0,
- t_shift: float = 1.0,
- **kwargs
- ) -> torch.Tensor:
- """
- Compute the sample at time `t_end` by Euler Solver.
- Args:
- x: The initial value at time `t_start`, with the shape (batch, seq_len, emb_dim).
-            text_condition: The text condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
-            speech_condition: The speech condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
- padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
- num_step: The number of ODE steps.
- guidance_scale: The scale for classifier-free guidance, which is
- a float or a tensor with the shape (batch, 1, 1).
- t_start: the start timestep in the range of [0, 1],
- which is a float or a tensor with the shape (batch, 1, 1).
-            t_end: the end timestep in the range of [0, 1],
- which is a float or a tensor with the shape (batch, 1, 1).
- t_shift: shift the t toward smaller numbers so that the sampling
- will emphasize low SNR region. Should be in the range of (0, 1].
- The shifting will be more significant when the number is smaller.
-
- Returns:
- The approximated solution at time `t_end`.
- """
- device = x.device
-
- if torch.is_tensor(t_start) and t_start.dim() > 0:
- timesteps = get_time_steps_batch(
- t_start=t_start,
- t_end=t_end,
- num_step=num_step,
- t_shift=t_shift,
- device=device,
- )
- else:
- timesteps = get_time_steps(
- t_start=t_start,
- t_end=t_end,
- num_step=num_step,
- t_shift=t_shift,
- device=device,
- )
- for step in range(num_step):
- v = self.model(
- t=timesteps[step],
- x=x,
- text_condition=text_condition,
- speech_condition=speech_condition,
- padding_mask=padding_mask,
- guidance_scale=guidance_scale,
- **kwargs
- )
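-            # Explicit Euler step: x_{k+1} = x_k + v(t_k, x_k) * (t_{k+1} - t_k).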
- x = x + v * (timesteps[step + 1] - timesteps[step])
- return x
-
-
-def get_time_steps(
- t_start: float = 0.0,
- t_end: float = 1.0,
- num_step: int = 10,
- t_shift: float = 1.0,
- device: torch.device = torch.device("cpu"),
-) -> torch.Tensor:
- """Compute the intermediate time steps for sampling.
-
- Args:
- t_start: The starting time of the sampling (default is 0).
-        t_end: The ending time of the sampling (default is 1).
-        num_step: The number of sampling steps.
- t_shift: shift the t toward smaller numbers so that the sampling
- will emphasize low SNR region. Should be in the range of (0, 1].
- The shifting will be more significant when the number is smaller.
- device: A torch device.
- Returns:
-        The time steps with the shape (num_step + 1,).
- """
-
- timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
-
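-    # Shift t -> t_shift * t / (1 + (t_shift - 1) * t), which pushes the grid towards
-    # t = 0 (the low-SNR end) when t_shift < 1. For example, t_shift=0.5 with
-    # num_step=4 gives [0.00, 0.14, 0.33, 0.60, 1.00] instead of the uniform
-    # [0.00, 0.25, 0.50, 0.75, 1.00].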
- timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
-
- return timesteps
-
-
-def get_time_steps_batch(
- t_start: torch.Tensor,
- t_end: torch.Tensor,
- num_step: int = 10,
- t_shift: float = 1.0,
- device: torch.device = torch.device("cpu"),
-) -> torch.Tensor:
- """Compute the intermediate time steps for sampling in the batch mode.
-
- Args:
- t_start: The starting time of the sampling (default is 0), with the shape (batch, 1, 1).
-        t_end: The ending time of the sampling (default is 1), with the shape (batch, 1, 1).
-        num_step: The number of sampling steps.
- t_shift: shift the t toward smaller numbers so that the sampling
- will emphasize low SNR region. Should be in the range of (0, 1].
- The shifting will be more significant when the number is smaller.
- device: A torch device.
- Returns:
-        The time steps with the shape (num_step + 1, batch, 1, 1).
- """
- while t_start.dim() > 1 and t_start.size(-1) == 1:
- t_start = t_start.squeeze(-1)
- while t_end.dim() > 1 and t_end.size(-1) == 1:
- t_end = t_end.squeeze(-1)
- assert t_start.dim() == t_end.dim() == 1
-
- timesteps_shape = (num_step + 1, t_start.size(0))
- timesteps = torch.zeros(timesteps_shape, device=device)
-
- for i in range(t_start.size(0)):
- timesteps[:, i] = torch.linspace(t_start[i], t_end[i], steps=num_step + 1)
-
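-    # Apply the same low-SNR shift as in get_time_steps, per utterance.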
- timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
-
- return timesteps.unsqueeze(-1).unsqueeze(-1)
diff --git a/egs/zipvoice/zipvoice/tokenizer.py b/egs/zipvoice/zipvoice/tokenizer.py
deleted file mode 100644
index ea25d3498..000000000
--- a/egs/zipvoice/zipvoice/tokenizer.py
+++ /dev/null
@@ -1,572 +0,0 @@
-# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao
-# Han Zhu,
-# Wei Kang)
-#
-# See ../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import re
-import unicodedata
-from functools import reduce
-from typing import Dict, List, Optional
-
-import cn2an
-import inflect
-import jieba
-from pypinyin import Style, lazy_pinyin
-from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
-
-try:
- from piper_phonemize import phonemize_espeak
-except Exception as ex:
- raise RuntimeError(
- f"{ex}\nPlease run\n"
- "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
- )
-
-_inflect = inflect.engine()
-_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
-_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
-_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
-_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
-_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
-_fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
-_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
-_number_re = re.compile(r"[0-9]+")
-
-# List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [
- (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
- for x in [
- ("mrs", "misess"),
- ("mr", "mister"),
- ("dr", "doctor"),
- ("st", "saint"),
- ("co", "company"),
- ("jr", "junior"),
- ("maj", "major"),
- ("gen", "general"),
- ("drs", "doctors"),
- ("rev", "reverend"),
- ("lt", "lieutenant"),
- ("hon", "honorable"),
- ("sgt", "sergeant"),
- ("capt", "captain"),
- ("esq", "esquire"),
- ("ltd", "limited"),
- ("col", "colonel"),
- ("ft", "fort"),
- ("etc", "et cetera"),
- ("btw", "by the way"),
- ]
-]
-
-
-def intersperse(sequence, item=0):
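-    # e.g. intersperse([1, 2, 3], item=0) -> [0, 1, 0, 2, 0, 3, 0]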
- result = [item] * (len(sequence) * 2 + 1)
- result[1::2] = sequence
- return result
-
-
-def expand_abbreviations(text):
- for regex, replacement in _abbreviations:
- text = re.sub(regex, replacement, text)
- return text
-
-
-def _remove_commas(m):
- return m.group(1).replace(",", "")
-
-
-def _expand_decimal_point(m):
- return m.group(1).replace(".", " point ")
-
-
-def _expand_percent(m):
- return m.group(1).replace("%", " percent ")
-
-
-def _expand_dollars(m):
- match = m.group(1)
- parts = match.split(".")
- if len(parts) > 2:
- return " " + match + " dollars " # Unexpected format
- dollars = int(parts[0]) if parts[0] else 0
- cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
- if dollars and cents:
- dollar_unit = "dollar" if dollars == 1 else "dollars"
- cent_unit = "cent" if cents == 1 else "cents"
- return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
- elif dollars:
- dollar_unit = "dollar" if dollars == 1 else "dollars"
- return " %s %s " % (dollars, dollar_unit)
- elif cents:
- cent_unit = "cent" if cents == 1 else "cents"
- return " %s %s " % (cents, cent_unit)
- else:
- return " zero dollars "
-
-
-def fraction_to_words(numerator, denominator):
- if numerator == 1 and denominator == 2:
- return " one half "
- if numerator == 1 and denominator == 4:
- return " one quarter "
- if denominator == 2:
- return " " + _inflect.number_to_words(numerator) + " halves "
- if denominator == 4:
- return " " + _inflect.number_to_words(numerator) + " quarters "
- return (
- " "
- + _inflect.number_to_words(numerator)
- + " "
- + _inflect.ordinal(_inflect.number_to_words(denominator))
- + " "
- )
-
-
-def _expand_fraction(m):
- numerator = int(m.group(1))
- denominator = int(m.group(2))
- return fraction_to_words(numerator, denominator)
-
-
-def _expand_ordinal(m):
- return " " + _inflect.number_to_words(m.group(0)) + " "
-
-
-def _expand_number(m):
- num = int(m.group(0))
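-    # Read numbers between 1000 and 3000 as years, e.g. 1984 -> "nineteen eighty-four";
-    # everything else is read as a plain cardinal.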
- if num > 1000 and num < 3000:
- if num == 2000:
- return " two thousand "
- elif num > 2000 and num < 2010:
- return " two thousand " + _inflect.number_to_words(num % 100) + " "
- elif num % 100 == 0:
- return " " + _inflect.number_to_words(num // 100) + " hundred "
- else:
- return (
- " "
- + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
- ", ", " "
- )
- + " "
- )
- else:
- return " " + _inflect.number_to_words(num, andword="") + " "
-
-
-# Normalize numbers pronunciation
-def normalize_numbers(text):
- text = re.sub(_comma_number_re, _remove_commas, text)
- text = re.sub(_pounds_re, r"\1 pounds", text)
- text = re.sub(_dollars_re, _expand_dollars, text)
- text = re.sub(_fraction_re, _expand_fraction, text)
- text = re.sub(_decimal_number_re, _expand_decimal_point, text)
- text = re.sub(_percent_number_re, _expand_percent, text)
- text = re.sub(_ordinal_re, _expand_ordinal, text)
- text = re.sub(_number_re, _expand_number, text)
- return text
-
-
-# Convert numbers to Chinese pronunciation
-def number_to_chinese(text):
- text = cn2an.transform(text, "an2cn")
- return text
-
-
-def map_punctuations(text):
- text = text.replace(",", ",")
- text = text.replace("。", ".")
- text = text.replace("!", "!")
- text = text.replace("?", "?")
- text = text.replace(";", ";")
- text = text.replace(":", ":")
- text = text.replace("、", ",")
- text = text.replace("‘", "'")
- text = text.replace("“", '"')
- text = text.replace("”", '"')
- text = text.replace("’", "'")
- text = text.replace("⋯", "…")
- text = text.replace("···", "…")
- text = text.replace("・・・", "…")
- text = text.replace("...", "…")
- return text
-
-
-def is_chinese(char):
- if char >= "\u4e00" and char <= "\u9fa5":
- return True
- else:
- return False
-
-
-def is_alphabet(char):
- if (char >= "\u0041" and char <= "\u005a") or (
- char >= "\u0061" and char <= "\u007a"
- ):
- return True
- else:
- return False
-
-
-def is_hangul(char):
- letters = unicodedata.normalize("NFD", char)
- return all(
- ["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters]
- )
-
-
-def is_japanese(char):
- return any(
- [
- start <= char <= end
- for start, end in [
- ("\u3041", "\u3096"),
- ("\u30a0", "\u30ff"),
- ("\uff5f", "\uff9f"),
- ("\u31f0", "\u31ff"),
- ("\u3220", "\u3243"),
- ("\u3280", "\u337f"),
- ]
- ]
- )
-
-
-def get_segment(text: str) -> List[str]:
- # sentence --> [ch_part, en_part, ch_part, ...]
- # example :
- # input : 我们是小米人,是吗? Yes I think so!霍...啦啦啦
- # output : [('我们是小米人,是吗? ', 'zh'), ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')]
- segments = []
- types = []
- flag = 0
- temp_seg = ""
- temp_lang = ""
-
- for i, ch in enumerate(text):
- if is_chinese(ch):
- types.append("zh")
- elif is_alphabet(ch):
- types.append("en")
- else:
- types.append("other")
-
- assert len(types) == len(text)
-
- for i in range(len(types)):
- # find the first char of the seg
- if flag == 0:
- temp_seg += text[i]
- temp_lang = types[i]
- flag = 1
- else:
- if temp_lang == "other":
- if types[i] == temp_lang:
- temp_seg += text[i]
- else:
- temp_seg += text[i]
- temp_lang = types[i]
- else:
- if types[i] == temp_lang:
- temp_seg += text[i]
- elif types[i] == "other":
- temp_seg += text[i]
- else:
- segments.append((temp_seg, temp_lang))
- temp_seg = text[i]
- temp_lang = types[i]
- flag = 1
-
- segments.append((temp_seg, temp_lang))
- return segments
-
-
-def preprocess(text: str) -> str:
- text = map_punctuations(text)
- return text
-
-
-def tokenize_ZH(text: str) -> List[str]:
- try:
- text = number_to_chinese(text)
- segs = list(jieba.cut(text))
- full = lazy_pinyin(
- segs, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True
- )
- phones = []
- for x in full:
- # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
- if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")):
- phones.append(x)
- continue
- initial = to_initials(x, strict=False)
- # don't want to share tokens with espeak tokens, so use tone3 style
- final = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
- if initial != "":
- # don't want to share tokens with espeak tokens, so add a '0' after each initial
- phones.append(initial + "0")
- if final != "":
- phones.append(final)
- return phones
- except Exception as ex:
- logging.warning(f"Tokenize ZH failed: {ex}")
- return []
-
-
-def tokenize_EN(text: str) -> List[str]:
- try:
- text = expand_abbreviations(text)
- text = normalize_numbers(text)
- tokens = phonemize_espeak(text, "en-us")
- tokens = reduce(lambda x, y: x + y, tokens)
- return tokens
- except Exception as ex:
- logging.warning(f"Tokenize EN failed: {ex}")
- return []
-
-
-class TokenizerEmilia(object):
- def __init__(self, token_file: Optional[str] = None, token_type="phone"):
- """
- Args:
-          token_file: the file that contains information that maps tokens to ids,
-            which is a text file with '{token} {token_id}' per line.
-          token_type: the token type; only "phone" is supported.
- """
- assert (
- token_type == "phone"
- ), f"Only support phone tokenizer for Emilia, but get {token_type}."
- self.has_tokens = False
- if token_file is None:
- logging.debug(
- "Initialize Tokenizer without tokens file, will fail when map to ids."
- )
- return
- self.token2id: Dict[str, int] = {}
- with open(token_file, "r", encoding="utf-8") as f:
- for line in f.readlines():
- info = line.rstrip().split("\t")
- token, id = info[0], int(info[1])
- assert token not in self.token2id, token
- self.token2id[token] = id
- self.pad_id = self.token2id["_"] # padding
-
- self.vocab_size = len(self.token2id)
- self.has_tokens = True
-
- def texts_to_token_ids(
- self,
- texts: List[str],
- ) -> List[List[int]]:
- return self.tokens_to_token_ids(self.texts_to_tokens(texts))
-
- def texts_to_tokens(
- self,
- texts: List[str],
- ) -> List[List[str]]:
- """
- Args:
- texts:
- A list of transcripts.
- Returns:
- Return a list of a list of tokens [utterance][token]
- """
- for i in range(len(texts)):
- # Text normalization
- texts[i] = preprocess(texts[i])
-
- phoneme_list = []
- for text in texts:
- # now only en and ch
- segments = get_segment(text)
- all_phoneme = []
- for index in range(len(segments)):
- seg = segments[index]
- if seg[1] == "zh":
- phoneme = tokenize_ZH(seg[0])
- else:
- if seg[1] != "en":
- logging.warning(
- f"The lang should be en, given {seg[1]}, skipping segment : {seg}"
- )
- continue
- phoneme = tokenize_EN(seg[0])
- all_phoneme += phoneme
- phoneme_list.append(all_phoneme)
- return phoneme_list
-
- def tokens_to_token_ids(
- self,
- tokens: List[List[str]],
- intersperse_blank: bool = False,
- ) -> List[List[int]]:
- """
- Args:
- tokens_list:
- A list of token list, each corresponding to one utterance.
- intersperse_blank:
- Whether to intersperse blanks in the token sequence.
-
- Returns:
- Return a list of token id list [utterance][token_id]
- """
- assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
- token_ids = []
-
- for tks in tokens:
- ids = []
- for t in tks:
- if t not in self.token2id:
- logging.warning(f"Skip OOV {t}")
- continue
- ids.append(self.token2id[t])
-
- if intersperse_blank:
- ids = intersperse(ids, self.pad_id)
-
- token_ids.append(ids)
-
- return token_ids
-
-
-class TokenizerLibriTTS(object):
- def __init__(self, token_file: str, token_type: str):
- """
- Args:
-          token_type: the type of tokenizer, e.g., bpe, char, phone.
-          token_file: the file that contains information that maps tokens to ids,
-            which is a text file with '{token} {token_id}' per line if the type is
-            char or phone, otherwise it is a bpe_model file.
- """
- self.type = token_type
- assert token_type in ["bpe", "char", "phone"]
- # Parse token file
-
- if token_type == "bpe":
- import sentencepiece as spm
-
- self.sp = spm.SentencePieceProcessor()
- self.sp.load(token_file)
- self.pad_id = self.sp.piece_to_id("")
- self.vocab_size = self.sp.get_piece_size()
- else:
- self.token2id: Dict[str, int] = {}
- with open(token_file, "r", encoding="utf-8") as f:
- for line in f.readlines():
- info = line.rstrip().split("\t")
- token, id = info[0], int(info[1])
- assert token not in self.token2id, token
- self.token2id[token] = id
- self.pad_id = self.token2id["_"] # padding
- self.vocab_size = len(self.token2id)
- try:
- from tacotron_cleaner.cleaners import custom_english_cleaners as cleaner
- except Exception as ex:
- raise RuntimeError(
- f"{ex}\nPlease run\n"
-                "pip install espnet_tts_frontend, refer to https://github.com/espnet/espnet_tts_frontend/"
- )
- self.cleaner = cleaner
-
- def texts_to_token_ids(
- self,
- texts: List[str],
- lang: str = "en-us",
- ) -> List[List[int]]:
- """
- Args:
- texts:
- A list of transcripts.
- lang:
- Language argument passed to phonemize_espeak().
-
- Returns:
- Return a list of token id list [utterance][token_id]
- """
- for i in range(len(texts)):
- # Text normalization
- texts[i] = self.cleaner(texts[i])
-
- if self.type == "bpe":
- token_ids_list = self.sp.encode(texts)
-
- elif self.type == "phone":
- token_ids_list = []
- for text in texts:
- tokens_list = phonemize_espeak(text.lower(), lang)
- tokens = []
- for t in tokens_list:
- tokens.extend(t)
- token_ids = []
- for t in tokens:
- if t not in self.token2id:
- logging.warning(f"Skip OOV {t}")
- continue
- token_ids.append(self.token2id[t])
-
- token_ids_list.append(token_ids)
- else:
- token_ids_list = []
- for text in texts:
- token_ids = []
- for t in text:
- if t not in self.token2id:
- logging.warning(f"Skip OOV {t}")
- continue
- token_ids.append(self.token2id[t])
-
- token_ids_list.append(token_ids)
-
- return token_ids_list
-
- def tokens_to_token_ids(
- self,
-        tokens_list: List[List[str]],
- ) -> List[List[int]]:
- """
- Args:
- tokens_list:
- A list of token list, each corresponding to one utterance.
-
- Returns:
- Return a list of token id list [utterance][token_id]
- """
- token_ids_list = []
-
- for tokens in tokens_list:
- token_ids = []
- for t in tokens:
- if t not in self.token2id:
- logging.warning(f"Skip OOV {t}")
- continue
- token_ids.append(self.token2id[t])
-
- token_ids_list.append(token_ids)
-
- return token_ids_list
-
-
-if __name__ == "__main__":
- text = "我们是5年小米人,是吗? Yes I think so! mr king, 5 years, from 2019 to 2024. 霍...啦啦啦超过90%的人咯...?!9204"
- tokenizer = TokenizerEmilia()
- tokens = tokenizer.texts_to_tokens([text])
- print(f"tokens : {tokens}")
- tokens2 = "|".join(tokens[0])
- print(f"tokens2 : {tokens2}")
- tokens2 = tokens2.split("|")
- assert tokens[0] == tokens2
diff --git a/egs/zipvoice/zipvoice/train_distill.py b/egs/zipvoice/zipvoice/train_distill.py
deleted file mode 100644
index 9e52a3790..000000000
--- a/egs/zipvoice/zipvoice/train_distill.py
+++ /dev/null
@@ -1,1043 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This script trains a ZipVoice-Distill model starting from a ZipVoice model.
-It has two distillation stages.
-
-Usage:
-
-(1) The first distillation stage with a fixed ZipVoice model as the teacher.
-python3 zipvoice/train_distill.py \
- --world-size 8 \
- --use-fp16 1 \
- --tensorboard 1 \
- --dataset "emilia" \
- --base-lr 0.0005 \
- --max-duration 500 \
- --token-file "data/tokens_emilia.txt" \
- --manifest-dir "data/fbank" \
- --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
- --num-updates 60000 \
- --distill-stage "first" \
- --exp-dir zipvoice/exp_zipvoice_distill_1stage
-
-(2) The second distillation stage with an EMA model as the teacher.
-python3 zipvoice/train_distill.py \
- --world-size 8 \
- --use-fp16 1 \
- --tensorboard 1 \
- --dataset "emilia" \
- --base-lr 0.0001 \
- --max-duration 500 \
- --token-file "data/tokens_emilia.txt" \
- --manifest-dir "data/fbank" \
- --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \
- --num-updates 2000 \
- --distill-stage "second" \
- --exp-dir zipvoice/exp_zipvoice_distill
-"""
-
-import argparse
-import copy
-import logging
-import os
-import random
-from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import optim
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from checkpoint import load_checkpoint, save_checkpoint
-from lhotse.cut import Cut, CutSet
-from lhotse.utils import fix_random_seed
-from model import get_distill_model, get_model
-from optim import FixedLRScheduler, ScaledAdam
-from tokenizer import TokenizerEmilia, TokenizerLibriTTS
-from torch import Tensor
-from torch.amp import GradScaler, autocast
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import Optimizer
-from torch.utils.tensorboard import SummaryWriter
-from train_flow import add_model_arguments, get_params
-from tts_datamodule import TtsDataModule
-from utils import (
- condition_time_mask,
- get_adjusted_batch_count,
- prepare_input,
- set_batch_count,
-)
-
-from icefall import diagnostics
-from icefall.checkpoint import (
- remove_checkpoints,
- save_checkpoint_with_global_batch_idx,
- update_averaged_model,
-)
-from icefall.dist import cleanup_dist, setup_dist
-from icefall.hooks import register_inf_check_hooks
-from icefall.utils import (
- AttributeDict,
- MetricsTracker,
- get_parameter_groups_with_lrs,
- make_pad_mask,
- setup_logger,
- str2bool,
-)
-
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
-
-
-def get_parser():
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter
- )
-
- parser.add_argument(
- "--world-size",
- type=int,
- default=1,
- help="Number of GPUs for DDP training.",
- )
-
- parser.add_argument(
- "--master-port",
- type=int,
- default=12354,
- help="Master port to use for DDP training.",
- )
-
- parser.add_argument(
- "--tensorboard",
- type=str2bool,
- default=True,
- help="Should various information be logged in tensorboard.",
- )
-
- parser.add_argument(
- "--num-epochs",
- type=int,
- default=1,
- help="Number of epochs to train.",
- )
-
- parser.add_argument(
- "--num-updates",
- type=int,
- default=0,
- help="Number of updates to train, will ignore num_epochs if > 0.",
- )
-
- parser.add_argument(
- "--start-epoch",
- type=int,
- default=1,
- help="""Resume training from this epoch. It should be positive.
- If larger than 1, it will load checkpoint from
- exp-dir/epoch-{start_epoch-1}.pt
- """,
- )
-
- parser.add_argument(
- "--teacher-model",
- type=str,
-        help="""Checkpoint of the pre-trained teacher model""",
- )
-
- parser.add_argument(
- "--exp-dir",
- type=str,
- default="zipvoice/exp_zipvoice_distill",
- help="""The experiment dir.
- It specifies the directory where all training related
- files, e.g., checkpoints, log, etc, are saved
- """,
- )
-
- parser.add_argument(
- "--base-lr", type=float, default=0.001, help="The base learning rate."
- )
-
- parser.add_argument(
- "--ref-duration",
- type=float,
- default=50,
-        help="Reference batch duration used to adjust batch counts for the various "
-        "schedules inside the model.",
- )
-
- parser.add_argument(
- "--seed",
- type=int,
- default=42,
- help="The seed for random generators intended for reproducibility",
- )
-
- parser.add_argument(
- "--print-diagnostics",
- type=str2bool,
- default=False,
- help="Accumulate stats on activations, print them and exit.",
- )
-
- parser.add_argument(
- "--inf-check",
- type=str2bool,
- default=False,
- help="Add hooks to check for infinite module outputs and gradients.",
- )
-
- parser.add_argument(
- "--save-every-n",
- type=int,
- default=4000,
-        help="""Save checkpoint after processing this number of batches
- periodically. We save checkpoint to exp-dir/ whenever
- params.batch_idx_train % save_every_n == 0. The checkpoint filename
- has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
- Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
- end of each epoch where `xxx` is the epoch number counting from 1.
- """,
- )
-
- parser.add_argument(
- "--keep-last-k",
- type=int,
- default=30,
- help="""Only keep this number of checkpoints on disk.
- For instance, if it is 3, there are only 3 checkpoints
- in the exp-dir with filenames `checkpoint-xxx.pt`.
- It does not affect checkpoints with name `epoch-xxx.pt`.
- """,
- )
-
- parser.add_argument(
- "--average-period",
- type=int,
- default=200,
- help="""Update the averaged model, namely `model_avg`, after processing
- this number of batches. `model_avg` is a separate version of model,
- in which each floating-point parameter is the average of all the
- parameters from the start of training. Each time we take the average,
- we do: `model_avg = model * (average_period / batch_idx_train) +
- model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
- """,
- )
-
- parser.add_argument(
- "--use-fp16",
- type=str2bool,
- default=False,
- help="Whether to use half precision training.",
- )
-
- parser.add_argument(
- "--feat-scale",
- type=float,
- default=0.1,
- help="The scale factor of fbank feature",
- )
-
- parser.add_argument(
- "--ema-decay",
- type=float,
- default=0.9999,
- help="The EMA decay factor of target model in distillation.",
- )
- parser.add_argument(
- "--distill-stage",
- type=str,
- choices=["first", "second"],
- help="The stage of distillation.",
- )
-
- parser.add_argument(
- "--dataset",
- type=str,
- default="emilia",
- choices=["emilia", "libritts"],
-        help="The training dataset to use.",
- )
-
- add_model_arguments(parser)
-
- return parser
-
-
-def ema(new_model, ema_model, decay):
- if isinstance(new_model, DDP):
- new_model = new_model.module
- if isinstance(ema_model, DDP):
- ema_model = ema_model.module
- new_model_dict = new_model.state_dict()
- ema_model_dict = ema_model.state_dict()
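-    # In-place EMA update: ema_param <- decay * ema_param + (1 - decay) * new_param.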
- for key in new_model_dict.keys():
- ema_model_dict[key].data.copy_(
- ema_model_dict[key].data * decay + new_model_dict[key].data * (1 - decay)
- )
-
-
-def resume_checkpoint(
-    params: AttributeDict,
-    model: nn.Module,
-    model_avg: nn.Module,
-    model_ema: Optional[nn.Module] = None,
-) -> Optional[Dict[str, Any]]:
- """Load checkpoint from file.
-
- If params.start_epoch is larger than 1, it will load the checkpoint from
- `params.start_epoch - 1`.
-
-    Apart from loading the state dicts for `model`, `model_avg` and `model_ema`,
-    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
-    and `best_valid_loss` in `params`.
-
- Args:
- params:
- The return value of :func:`get_params`.
- model:
- The training model.
- Returns:
- Return a dict containing previously saved training info.
- """
- filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
-
- assert filename.is_file(), f"{filename} does not exist!"
-
- saved_params = load_checkpoint(
- filename, model=model, model_avg=model_avg, model_ema=model_ema, strict=True
- )
-
- if params.start_epoch > 1:
- keys = [
- "best_train_epoch",
- "best_valid_epoch",
- "batch_idx_train",
- "best_train_loss",
- "best_valid_loss",
- ]
- for k in keys:
- params[k] = saved_params[k]
-
- return saved_params
-
-
-def compute_fbank_loss(
- params: AttributeDict,
- model: Union[nn.Module, DDP],
- teacher_model: Union[nn.Module, DDP],
- features: Tensor,
- features_lens: Tensor,
- tokens: List[List[int]],
- is_training: bool,
-) -> Tuple[Tensor, MetricsTracker]:
- """
- Compute loss given the model and its inputs.
-
- Args:
- params:
- Parameters for training. See :func:`get_params`.
- model:
- The model for training.
- teacher_model:
- The teacher model for distillation.
- features:
- The target acoustic feature.
- features_lens:
- The number of frames of each utterance.
- tokens:
-        Input tokens representing the transcripts.
- is_training:
- True for training. False for validation. When it is True, this
- function enables autograd during computation; when it is False, it
- disables autograd.
- """
-
- device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-
- batch_size, num_frames, _ = features.shape
-
- features = torch.nn.functional.pad(
- features, (0, 0, 0, num_frames - features.size(1))
- ) # (B, T, F)
- noise = torch.randn_like(features) # (B, T, F)
-
- # Sampling t and guidance_scale from uniform distribution
-
- t_value = random.random()
- t = torch.ones(batch_size, 1, 1, device=device) * t_value
- if params.distill_stage == "first":
- guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2
- else:
- guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 + 1
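-    # Flow-matching interpolation between noise (t=0) and data (t=1): x_t = t * x_1 + (1 - t) * x_0.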
- xt = features * t + noise * (1 - t)
- t_delta_fix = random.uniform(0.0, min(0.3, 1 - t_value))
- t_delta_ema = random.uniform(0.0, min(0.3, 1 - t_value - t_delta_fix))
- t_dest = t + t_delta_fix + t_delta_ema
-
- with torch.no_grad():
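-        # Randomly mask 70%-100% of the speech-condition frames (True = masked);
-        # the loss below is computed only on masked, non-padding frames.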
- speech_condition_mask = condition_time_mask(
- features_lens=features_lens,
- mask_percent=(0.7, 1.0),
- max_len=features.size(1),
- )
-
- if params.distill_stage == "first":
- teacher_x_t_mid, _ = teacher_model.sample_intermediate(
- tokens=tokens,
- features=features,
- features_lens=features_lens,
- noise=xt,
- speech_condition_mask=speech_condition_mask,
- t_start=t,
- t_end=t + t_delta_fix,
- num_step=1,
- guidance_scale=guidance_scale,
- )
-
- target_x1, _ = teacher_model.sample_intermediate(
- tokens=tokens,
- features=features,
- features_lens=features_lens,
- noise=teacher_x_t_mid,
- speech_condition_mask=speech_condition_mask,
- t_start=t + t_delta_fix,
- t_end=t_dest,
- num_step=1,
- guidance_scale=guidance_scale,
- )
- else:
- teacher_x_t_mid, _ = teacher_model(
- tokens=tokens,
- features=features,
- features_lens=features_lens,
- noise=xt,
- speech_condition_mask=speech_condition_mask,
- t_start=t,
- t_end=t + t_delta_fix,
- num_step=1,
- guidance_scale=guidance_scale,
- )
-
- target_x1, _ = teacher_model(
- tokens=tokens,
- features=features,
- features_lens=features_lens,
- noise=teacher_x_t_mid,
- speech_condition_mask=speech_condition_mask,
- t_start=t + t_delta_fix,
- t_end=t_dest,
- num_step=1,
- guidance_scale=guidance_scale,
- )
-
- with torch.set_grad_enabled(is_training):
-
- pred_x1, _ = model(
- tokens=tokens,
- features=features,
- features_lens=features_lens,
- noise=xt,
- speech_condition_mask=speech_condition_mask,
- t_start=t,
- t_end=t_dest,
- num_step=1,
- guidance_scale=guidance_scale,
- )
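-        # The student predicts the endpoint in a single step; convert it to an
-        # average velocity over [t, t_dest].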
- pred_v = (pred_x1 - xt) / (t_dest - t)
-
- padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
- loss_mask = speech_condition_mask & (~padding_mask)
-
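-        # Distillation target: the average velocity implied by the teacher's
-        # two-step trajectory over the same interval.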
- target_v = (target_x1 - xt) / (t_dest - t)
- loss = torch.mean((pred_v[loss_mask] - target_v[loss_mask]) ** 2)
-
- ut = features - noise # (B, T, F)
-
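-        # ref_loss compares against the ground-truth flow-matching velocity
-        # u_t = x_1 - x_0; it is only logged for monitoring.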
- ref_loss = torch.mean((pred_v[loss_mask] - ut[loss_mask]) ** 2)
-
- assert loss.requires_grad == is_training
- info = MetricsTracker()
- num_frames = features_lens.sum().item()
- info["frames"] = num_frames
- info["loss"] = loss.detach().cpu().item() * num_frames
- info["ref_loss"] = ref_loss.detach().cpu().item() * num_frames
- return loss, info
-
-
-def train_one_epoch(
- params: AttributeDict,
- model: Union[nn.Module, DDP],
- teacher_model: Union[nn.Module, DDP],
- tokenizer: TokenizerEmilia,
- optimizer: Optimizer,
- scheduler: LRSchedulerType,
- train_dl: torch.utils.data.DataLoader,
- valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
- model_avg: Optional[nn.Module] = None,
- tb_writer: Optional[SummaryWriter] = None,
- world_size: int = 1,
- rank: int = 0,
-) -> None:
- """Train the model for one epoch.
-
- The training loss from the mean of all frames is saved in
- `params.train_loss`. It runs the validation process every
- `params.valid_interval` batches.
-
- Args:
- params:
- It is returned by :func:`get_params`.
- model:
- The model for training.
- teacher_model:
- The model for distillation.
- tokenizer:
- Used to convert text to tokens.
- optimizer:
- The optimizer.
- scheduler:
-        The learning rate scheduler; step_batch() is called every batch and
-        step_epoch() once per epoch.
- train_dl:
- Dataloader for the training dataset.
- valid_dl:
- Dataloader for the validation dataset.
- scaler:
-        The scaler used for mixed precision training.
- tb_writer:
- Writer to write log messages to tensorboard.
- world_size:
- Number of nodes in DDP training. If it is 1, DDP is disabled.
- rank:
- The rank of the node in DDP training. If no DDP is used, it should
- be set to 0.
- """
- model.train()
- device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-
- # used to track the stats over iterations in one epoch
- tot_loss = MetricsTracker()
-
- saved_bad_model = False
-
- def save_bad_model(suffix: str = ""):
- save_checkpoint(
- filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
- model=model,
- model_avg=model_avg,
- model_ema=teacher_model,
- params=params,
- optimizer=optimizer,
- scheduler=scheduler,
- sampler=train_dl.sampler,
- scaler=scaler,
- rank=0,
- )
-
- for batch_idx, batch in enumerate(train_dl):
-
- if batch_idx % 10 == 0:
- set_batch_count(model, get_adjusted_batch_count(params) + 100000)
-
- if (
- params.valid_interval is None
- and batch_idx == 0
- and not params.print_diagnostics
- ) or (
- params.valid_interval is not None
- and params.batch_idx_train % params.valid_interval == 0
- and not params.print_diagnostics
- ):
- logging.info("Computing validation loss")
- valid_info = compute_validation_loss(
- params=params,
- model=model,
- teacher_model=teacher_model,
- tokenizer=tokenizer,
- valid_dl=valid_dl,
- world_size=world_size,
- )
- model.train()
- logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
- logging.info(
- f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
- )
- if tb_writer is not None:
- valid_info.write_summary(
- tb_writer, "train/valid_", params.batch_idx_train
- )
-
- params.batch_idx_train += 1
-
- batch_size = len(batch["text"])
-
- tokens, features, features_lens = prepare_input(
- params=params,
- batch=batch,
- device=device,
- tokenizer=tokenizer,
- return_tokens=True,
- return_feature=True,
- )
-
- try:
- with autocast("cuda", enabled=params.use_fp16):
- loss, loss_info = compute_fbank_loss(
- params=params,
- model=model,
- teacher_model=teacher_model,
- features=features,
- features_lens=features_lens,
- tokens=tokens,
- is_training=True,
- )
-
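-                # Exponentially decayed running sum of the loss statistics
-                # (decay constant: reset_interval batches).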
- tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
- scaler.scale(loss).backward()
-
- scheduler.step_batch(params.batch_idx_train)
- scaler.step(optimizer)
- scaler.update()
- optimizer.zero_grad()
- if params.distill_stage == "second":
- ema(model, teacher_model, params.ema_decay)
- except RuntimeError as e:
- if "out of memory" in str(e):
- logging.info(f"out of memory error at rank {rank}")
-                torch.cuda.empty_cache()
-                raise
- else:
- logging.info(f"Caught exception : {e}.")
- save_bad_model()
- raise
- except Exception as e:
- logging.info(f"Caught exception : {e}.")
- save_bad_model()
- raise
-
- if params.print_diagnostics and batch_idx == 5:
- return
-
- if (
- rank == 0
- and params.batch_idx_train > 0
- and params.batch_idx_train % params.average_period == 0
- ):
- update_averaged_model(
- params=params,
- model_cur=model,
- model_avg=model_avg,
- )
-
- if (
- params.batch_idx_train > 0
- and params.batch_idx_train % params.save_every_n == 0
- ):
- save_checkpoint_with_global_batch_idx(
- out_dir=params.exp_dir,
- global_batch_idx=params.batch_idx_train,
- model=model,
- model_avg=model_avg,
- params=params,
- optimizer=optimizer,
- scheduler=scheduler,
- sampler=train_dl.sampler,
- scaler=scaler,
- rank=rank,
- )
- remove_checkpoints(
- out_dir=params.exp_dir,
- topk=params.keep_last_k,
- rank=rank,
- )
- if (
- params.batch_idx_train > 0
- and params.num_updates > 0
- and params.batch_idx_train > params.num_updates
- ):
- break
- if params.batch_idx_train % 100 == 0 and params.use_fp16:
- # If the grad scale was less than 1, try increasing it. The _growth_interval
- # of the grad scaler is configurable, but we can't configure it to have different
- # behavior depending on the current grad scale.
- cur_grad_scale = scaler._scale.item()
-
- if cur_grad_scale < 1024.0 or (
- cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
- ):
- scaler.update(cur_grad_scale * 2.0)
- if cur_grad_scale < 0.01:
- if not saved_bad_model:
- save_bad_model(suffix="-first-warning")
- saved_bad_model = True
- logging.warning(f"Grad scale is small: {cur_grad_scale}")
- if cur_grad_scale < 1.0e-05:
- save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
-
- if params.batch_idx_train % params.log_interval == 0:
- cur_lr = max(scheduler.get_last_lr())
- cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
-
- logging.info(
- f"Epoch {params.cur_epoch}, batch {batch_idx}, "
- f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
- f"loss[{loss_info}], tot_loss[{tot_loss}], "
- f"cur_lr: {cur_lr:.2e}, "
- + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
- )
-
- if tb_writer is not None:
- tb_writer.add_scalar(
- "train/learning_rate", cur_lr, params.batch_idx_train
- )
- loss_info.write_summary(
- tb_writer, "train/current_", params.batch_idx_train
- )
- tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
- if params.use_fp16:
- tb_writer.add_scalar(
- "train/grad_scale", cur_grad_scale, params.batch_idx_train
- )
-
- loss_value = tot_loss["loss"]
- params.train_loss = loss_value
- if params.train_loss < params.best_train_loss:
- params.best_train_epoch = params.cur_epoch
- params.best_train_loss = params.train_loss
-
-
-def compute_validation_loss(
- params: AttributeDict,
- model: Union[nn.Module, DDP],
- teacher_model: Optional[nn.Module],
- tokenizer: TokenizerEmilia,
- valid_dl: torch.utils.data.DataLoader,
- world_size: int = 1,
-) -> MetricsTracker:
- """Run the validation process."""
-
- model.eval()
- device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-
- # used to summary the stats over iterations
- tot_loss = MetricsTracker()
-
- for batch_idx, batch in enumerate(valid_dl):
- tokens, features, features_lens = prepare_input(
- params=params,
- batch=batch,
- device=device,
- tokenizer=tokenizer,
- return_tokens=True,
- return_feature=True,
- )
-
- loss, loss_info = compute_fbank_loss(
- params=params,
- model=model,
- teacher_model=teacher_model,
- features=features,
- features_lens=features_lens,
- tokens=tokens,
- is_training=False,
- )
- assert loss.requires_grad is False
- tot_loss = tot_loss + loss_info
-
- if world_size > 1:
- tot_loss.reduce(loss.device)
-
- loss_value = tot_loss["loss"]
- if loss_value < params.best_valid_loss:
- params.best_valid_epoch = params.cur_epoch
- params.best_valid_loss = loss_value
-
- return tot_loss
-
-
-def run(rank, world_size, args):
- """
- Args:
- rank:
- It is a value between 0 and `world_size-1`, which is
- passed automatically by `mp.spawn()` in :func:`main`.
- The node with rank 0 is responsible for saving checkpoint.
- world_size:
- Number of GPUs for DDP training.
- args:
- The return value of get_parser().parse_args()
- """
- params = get_params()
- params.update(vars(args))
-
- fix_random_seed(params.seed)
- if world_size > 1:
- setup_dist(rank, world_size, params.master_port)
-
- setup_logger(f"{params.exp_dir}/log/log-train")
- os.makedirs(f"{params.exp_dir}/fbank", exist_ok=True)
-
- logging.info("Training started")
-
- if args.tensorboard and rank == 0:
- tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
- else:
- tb_writer = None
-
- device = torch.device("cpu")
- if torch.cuda.is_available():
- device = torch.device("cuda", rank)
- logging.info(f"Device: {device}")
-
- if params.dataset == "emilia":
- tokenizer = TokenizerEmilia(
- token_file=params.token_file, token_type=params.token_type
- )
- elif params.dataset == "libritts":
- tokenizer = TokenizerLibriTTS(
- token_file=params.token_file, token_type=params.token_type
- )
-
- params.vocab_size = tokenizer.vocab_size
- params.pad_id = tokenizer.pad_id
-
- params.device = device
-
- logging.info(params)
-
- logging.info("About to create model")
-
- assert params.teacher_model is not None
- logging.info(f"Loading pre-trained model from {params.teacher_model}")
- model = get_distill_model(params)
- _ = load_checkpoint(
- filename=params.teacher_model,
- model=model,
- strict=(params.distill_stage == "second"),
- )
-
- if params.distill_stage == "first":
- teacher_model = get_model(params)
- _ = load_checkpoint(
- filename=params.teacher_model, model=teacher_model, strict=True
- )
- else:
- teacher_model = copy.deepcopy(model)
-
- num_param = sum([p.numel() for p in model.parameters()])
- logging.info(f"Number of parameters : {num_param}")
-
- model_avg: Optional[nn.Module] = None
- if rank == 0:
- # model_avg is only used with rank 0
- model_avg = copy.deepcopy(model).to(torch.float64)
- assert params.start_epoch > 0, params.start_epoch
- if params.start_epoch > 1:
- logging.info(f"Resuming from epoch {params.start_epoch}")
- if params.distill_stage == "first":
- checkpoints = resume_checkpoint(
- params=params, model=model, model_avg=model_avg
- )
- else:
- checkpoints = resume_checkpoint(
- params=params, model=model, model_avg=model_avg, model_ema=teacher_model
- )
-
- model = model.to(device)
- teacher_model.to(device)
- teacher_model.eval()
-
- if world_size > 1:
- logging.info("Using DDP")
- model = DDP(model, device_ids=[rank], find_unused_parameters=True)
-
- # only update the fm_decoder
- num_trainable = 0
- for name, p in model.named_parameters():
- if "fm_decoder" in name:
- p.requires_grad = True
- num_trainable += p.numel()
- else:
- p.requires_grad = False
-
- logging.info(
- "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
- num_trainable, num_trainable / num_param * 100
- )
- )
-
- optimizer = ScaledAdam(
- get_parameter_groups_with_lrs(
- model,
- lr=params.base_lr,
- include_names=True,
- ),
- lr=params.base_lr, # should have no effect
- clipping_scale=2.0,
- )
-
- scheduler = FixedLRScheduler(optimizer)
-
- scaler = GradScaler("cuda", enabled=params.use_fp16)
-
- if params.start_epoch > 1 and checkpoints is not None:
- # load state_dict for optimizers
- if "optimizer" in checkpoints:
- logging.info("Loading optimizer state dict")
- optimizer.load_state_dict(checkpoints["optimizer"])
-
- # load state_dict for schedulers
- if "scheduler" in checkpoints:
- logging.info("Loading scheduler state dict")
- scheduler.load_state_dict(checkpoints["scheduler"])
-
- if "grad_scaler" in checkpoints:
- logging.info("Loading grad scaler state dict")
- scaler.load_state_dict(checkpoints["grad_scaler"])
-
- if params.print_diagnostics:
- opts = diagnostics.TensorDiagnosticOptions(
- 512
- ) # allow 4 megabytes per sub-module
- diagnostic = diagnostics.attach_diagnostics(model, opts)
-
- if params.inf_check:
- register_inf_check_hooks(model)
-
- def remove_short_and_long_utt_emilia(c: Cut):
- if c.duration < 1.0 or c.duration > 30.0:
- return False
- return True
-
- def remove_short_and_long_utt_libritts(c: Cut):
- if c.duration < 1.0 or c.duration > 20.0:
- return False
- return True
-
- datamodule = TtsDataModule(args)
- if params.dataset == "emilia":
- train_cuts = CutSet.mux(
- datamodule.train_emilia_EN_cuts(),
- datamodule.train_emilia_ZH_cuts(),
- weights=[46000, 49000],
- )
- train_cuts = train_cuts.filter(remove_short_and_long_utt_emilia)
- dev_cuts = CutSet.mux(
- datamodule.dev_emilia_EN_cuts(),
- datamodule.dev_emilia_ZH_cuts(),
- weights=[0.5, 0.5],
- )
- elif params.dataset == "libritts":
- train_cuts = datamodule.train_libritts_cuts()
- train_cuts = train_cuts.filter(remove_short_and_long_utt_libritts)
- dev_cuts = datamodule.dev_libritts_cuts()
-
- train_dl = datamodule.train_dataloaders(train_cuts)
-
- valid_dl = datamodule.dev_dataloaders(dev_cuts)
-
- for epoch in range(params.start_epoch, params.num_epochs + 1):
- logging.info(f"Start epoch {epoch}")
-
- scheduler.step_epoch(epoch - 1)
- fix_random_seed(params.seed + epoch - 1)
- train_dl.sampler.set_epoch(epoch - 1)
-
- params.cur_epoch = epoch
-
- if tb_writer is not None:
- tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
-
- train_one_epoch(
- params=params,
- model=model,
- model_avg=model_avg,
- teacher_model=teacher_model,
- tokenizer=tokenizer,
- optimizer=optimizer,
- scheduler=scheduler,
- train_dl=train_dl,
- valid_dl=valid_dl,
- scaler=scaler,
- tb_writer=tb_writer,
- world_size=world_size,
- rank=rank,
- )
-
- if params.print_diagnostics:
- diagnostic.print_diagnostics()
- break
-
- filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
- save_checkpoint(
- filename=filename,
- params=params,
- model=model,
- model_avg=model_avg,
- model_ema=teacher_model,
- optimizer=optimizer,
- scheduler=scheduler,
- sampler=train_dl.sampler,
- scaler=scaler,
- rank=rank,
- )
-
- if rank == 0:
- if params.best_train_epoch == params.cur_epoch:
- best_train_filename = params.exp_dir / "best-train-loss.pt"
- copyfile(src=filename, dst=best_train_filename)
-
- if params.best_valid_epoch == params.cur_epoch:
- best_valid_filename = params.exp_dir / "best-valid-loss.pt"
- copyfile(src=filename, dst=best_valid_filename)
-
- logging.info("Done!")
-
- if world_size > 1:
- torch.distributed.barrier()
- cleanup_dist()
-
-
-def main():
- parser = get_parser()
- TtsDataModule.add_arguments(parser)
- args = parser.parse_args()
- args.exp_dir = Path(args.exp_dir)
-
- world_size = args.world_size
- assert world_size >= 1
- if world_size > 1:
- mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
- else:
- run(rank=0, world_size=1, args=args)
-
-
-if __name__ == "__main__":
- torch.set_num_threads(1)
- torch.set_num_interop_threads(1)
- main()
diff --git a/egs/zipvoice/zipvoice/train_flow.py b/egs/zipvoice/zipvoice/train_flow.py
deleted file mode 100644
index 0bf023273..000000000
--- a/egs/zipvoice/zipvoice/train_flow.py
+++ /dev/null
@@ -1,1108 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Wei Kang,
-# Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script trains a ZipVoice model with the flow-matching loss.
-
-Usage:
-
-python3 zipvoice/train_flow.py \
- --world-size 8 \
- --use-fp16 1 \
- --dataset emilia \
- --max-duration 500 \
- --lr-hours 30000 \
- --lr-batches 7500 \
- --token-file "data/tokens_emilia.txt" \
- --manifest-dir "data/fbank" \
- --num-epochs 11 \
- --exp-dir zipvoice/exp_zipvoice
-"""
-
-import argparse
-import copy
-import logging
-import os
-from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import optim
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from checkpoint import load_checkpoint, save_checkpoint
-from lhotse.cut import Cut, CutSet
-from lhotse.utils import fix_random_seed
-from model import get_model
-from optim import Eden, ScaledAdam
-from tokenizer import TokenizerEmilia, TokenizerLibriTTS
-from torch import Tensor
-from torch.amp import GradScaler, autocast
-from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import Optimizer
-from torch.utils.tensorboard import SummaryWriter
-from tts_datamodule import TtsDataModule
-from utils import get_adjusted_batch_count, prepare_input, set_batch_count
-
-from icefall import diagnostics
-from icefall.checkpoint import (
- remove_checkpoints,
- save_checkpoint_with_global_batch_idx,
- update_averaged_model,
-)
-from icefall.dist import cleanup_dist, setup_dist
-from icefall.env import get_env_info
-from icefall.hooks import register_inf_check_hooks
-from icefall.utils import (
- AttributeDict,
- MetricsTracker,
- get_parameter_groups_with_lrs,
- setup_logger,
- str2bool,
-)
-
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
-
-
-def add_model_arguments(parser: argparse.ArgumentParser):
- parser.add_argument(
- "--fm-decoder-downsampling-factor",
- type=str,
- default="1,2,4,2,1",
- help="Downsampling factor for each stack of encoder layers.",
- )
-
- parser.add_argument(
- "--fm-decoder-num-layers",
- type=str,
- default="2,2,4,4,4",
- help="Number of zipformer encoder layers per stack, comma separated.",
- )
-
- parser.add_argument(
- "--fm-decoder-cnn-module-kernel",
- type=str,
- default="31,15,7,15,31",
- help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
- "a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--fm-decoder-feedforward-dim",
- type=int,
- default=1536,
-        help="Feedforward dimension of the zipformer encoder layers (a single int shared by all stacks).",
- )
-
- parser.add_argument(
- "--fm-decoder-num-heads",
- type=int,
- default=4,
-        help="Number of attention heads in the zipformer encoder layers (a single int shared by all stacks).",
- )
-
- parser.add_argument(
- "--fm-decoder-dim",
- type=int,
- default=512,
-        help="Embedding dimension of the encoder stacks (a single int shared by all stacks).",
- )
-
- parser.add_argument(
- "--text-encoder-downsampling-factor",
- type=str,
- default="1",
- help="Downsampling factor for each stack of encoder layers.",
- )
-
- parser.add_argument(
- "--text-encoder-num-layers",
- type=str,
- default="4",
- help="Number of zipformer encoder layers per stack, comma separated.",
- )
-
- parser.add_argument(
- "--text-encoder-feedforward-dim",
- type=int,
- default=512,
-        help="Feedforward dimension of the zipformer encoder layers (a single int shared by all stacks).",
- )
-
- parser.add_argument(
- "--text-encoder-cnn-module-kernel",
- type=str,
- default="9",
- help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
- "a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--text-encoder-num-heads",
- type=int,
- default=4,
-        help="Number of attention heads in the zipformer encoder layers (a single int shared by all stacks).",
- )
-
- parser.add_argument(
- "--text-encoder-dim",
- type=int,
- default=192,
-        help="Embedding dimension of the encoder stacks (a single int shared by all stacks).",
- )
-
- parser.add_argument(
- "--query-head-dim",
- type=int,
- default=32,
-        help="Query/key dimension per attention head (a single int shared by all stacks).",
- )
-
- parser.add_argument(
- "--value-head-dim",
- type=int,
- default=12,
-        help="Value dimension per attention head (a single int shared by all stacks).",
- )
-
- parser.add_argument(
- "--pos-head-dim",
- type=int,
- default=4,
-        help="Positional-encoding dimension per attention head (a single int shared by all stacks).",
- )
-
- parser.add_argument(
- "--pos-dim",
- type=int,
- default=48,
- help="Positional-encoding embedding dimension",
- )
-
- parser.add_argument(
- "--time-embed-dim",
- type=int,
- default=192,
-        help="Embedding dimension of the timestep embedding.",
- )
-
- parser.add_argument(
- "--text-embed-dim",
- type=int,
- default=192,
-        help="Embedding dimension of the text embedding.",
- )
-
- parser.add_argument(
- "--token-type",
- type=str,
- default="phone",
- choices=["phone", "char", "bpe"],
-        help="Input token type of the TTS model. By default, "
-        "we use phone for Emilia and char for LibriTTS.",
- )
-
- parser.add_argument(
- "--token-file",
- type=str,
- default="data/tokens_emilia.txt",
-        help="The file that maps tokens to ids: a text file with one "
-        "'{token}\t{token_id}' pair per line if the token type is char or "
-        "phone; otherwise a bpe model file.",
- )
-
-
-def get_parser():
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter
- )
-
- parser.add_argument(
- "--world-size",
- type=int,
- default=1,
- help="Number of GPUs for DDP training.",
- )
-
- parser.add_argument(
- "--master-port",
- type=int,
- default=12354,
- help="Master port to use for DDP training.",
- )
-
- parser.add_argument(
- "--tensorboard",
- type=str2bool,
- default=True,
-        help="Whether to log various information to tensorboard.",
- )
-
- parser.add_argument(
- "--num-epochs",
- type=int,
- default=11,
- help="Number of epochs to train.",
- )
-
- parser.add_argument(
- "--start-epoch",
- type=int,
- default=1,
- help="""Resume training from this epoch. It should be positive.
- If larger than 1, it will load checkpoint from
- exp-dir/epoch-{start_epoch-1}.pt
- """,
- )
-
- parser.add_argument(
- "--checkpoint",
- type=str,
- default=None,
-        help="""The checkpoint of a pre-trained model; it will be loaded if not None.
-        """,
- )
-
- parser.add_argument(
- "--exp-dir",
- type=str,
- default="zipvoice/exp_zipvoice",
- help="""The experiment dir.
- It specifies the directory where all training related
- files, e.g., checkpoints, log, etc, are saved
- """,
- )
-
- parser.add_argument(
- "--base-lr", type=float, default=0.02, help="The base learning rate."
- )
-
- parser.add_argument(
- "--lr-batches",
- type=float,
- default=7500,
-        help="""Number of steps that affects how rapidly the learning rate
-        decreases. We suggest not changing this.""",
- )
-
- parser.add_argument(
- "--lr-epochs",
- type=float,
- default=10,
- help="""Number of epochs that affects how rapidly the learning rate decreases.
- """,
- )
-
- parser.add_argument(
- "--lr-hours",
- type=float,
- default=0,
-        help="""If positive, --lr-epochs is ignored and this specifies the number of hours
- that affects how rapidly the learning rate decreases.
- """,
- )
-
- parser.add_argument(
- "--ref-duration",
- type=float,
- default=50,
-        help="Reference batch duration used to adjust batch counts that drive various "
-        "schedules inside the model.",
- )
-
- parser.add_argument(
- "--seed",
- type=int,
- default=42,
- help="The seed for random generators intended for reproducibility",
- )
-
- parser.add_argument(
- "--print-diagnostics",
- type=str2bool,
- default=False,
- help="Accumulate stats on activations, print them and exit.",
- )
-
- parser.add_argument(
- "--inf-check",
- type=str2bool,
- default=False,
- help="Add hooks to check for infinite module outputs and gradients.",
- )
-
- parser.add_argument(
- "--save-every-n",
- type=int,
- default=4000,
-        help="""Save checkpoint after processing this number of batches
- periodically. We save checkpoint to exp-dir/ whenever
- params.batch_idx_train % save_every_n == 0. The checkpoint filename
- has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
- Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
- end of each epoch where `xxx` is the epoch number counting from 1.
- """,
- )
-
- parser.add_argument(
- "--keep-last-k",
- type=int,
- default=30,
- help="""Only keep this number of checkpoints on disk.
- For instance, if it is 3, there are only 3 checkpoints
- in the exp-dir with filenames `checkpoint-xxx.pt`.
- It does not affect checkpoints with name `epoch-xxx.pt`.
- """,
- )
-
- parser.add_argument(
- "--average-period",
- type=int,
- default=200,
- help="""Update the averaged model, namely `model_avg`, after processing
- this number of batches. `model_avg` is a separate version of model,
- in which each floating-point parameter is the average of all the
- parameters from the start of training. Each time we take the average,
- we do: `model_avg = model * (average_period / batch_idx_train) +
- model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
- """,
- )
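-    # Worked example of the update above (illustrative): with average_period=200 and
-    # batch_idx_train=1000, model_avg = 0.2 * model + 0.8 * model_avg.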
-
- parser.add_argument(
- "--use-fp16",
- type=str2bool,
- default=True,
- help="Whether to use half precision training.",
- )
-
- parser.add_argument(
- "--feat-scale",
- type=float,
- default=0.1,
-        help="The scale factor of the fbank features.",
- )
-
- parser.add_argument(
- "--condition-drop-ratio",
- type=float,
- default=0.2,
-        help="The probability of dropping the text condition during training.",
- )
-
- parser.add_argument(
- "--dataset",
- type=str,
- default="emilia",
- choices=["emilia", "libritts"],
-        help="The training dataset to use.",
- )
-
- add_model_arguments(parser)
-
- return parser
-
-
-def get_params() -> AttributeDict:
- """Return a dict containing training parameters.
-
- All training related parameters that are not passed from the commandline
- are saved in the variable `params`.
-
- Commandline options are merged into `params` after they are parsed, so
- you can also access them via `params`.
-
- Explanation of options saved in `params`:
-
- - best_train_loss: Best training loss so far. It is used to select
- the model that has the lowest training loss. It is
- updated during the training.
-
- - best_valid_loss: Best validation loss so far. It is used to select
- the model that has the lowest validation loss. It is
- updated during the training.
-
- - best_train_epoch: It is the epoch that has the best training loss.
-
- - best_valid_epoch: It is the epoch that has the best validation loss.
-
-        - batch_idx_train: Used to write statistics to tensorboard. It
- contains number of batches trained so far across
- epochs.
-
-        - log_interval:  Print training loss if batch_idx % log_interval is 0
-
- - reset_interval: Reset statistics if batch_idx % reset_interval is 0
-
- - valid_interval: Run validation if batch_idx % valid_interval is 0
-
-        - sampling_rate: Sampling rate of the waveform.
-
- - frame_shift_ms: Frame shift in milliseconds.
-
- - feat_dim: The model input dim. It has to match the one used
- in computing features.
-
- - env_info: A dict containing information about the environment.
-
- """
- params = AttributeDict(
- {
- "best_train_loss": float("inf"),
- "best_valid_loss": float("inf"),
- "best_train_epoch": -1,
- "best_valid_epoch": -1,
- "batch_idx_train": 0,
- "log_interval": 50,
- "reset_interval": 200,
- "valid_interval": 4000,
- "sampling_rate": 24000,
- "frame_shift_ms": 256 / 24000 * 1000,
- "feat_dim": 100,
- "env_info": get_env_info(),
- }
- )
-
- return params
-
-
-def resume_checkpoint(
- params: AttributeDict, model: nn.Module, model_avg: nn.Module
-) -> Optional[Dict[str, Any]]:
- """Load checkpoint from file.
-
- If params.start_epoch is larger than 1, it will load the checkpoint from
- `params.start_epoch - 1`.
-
- Apart from loading state dict for `model` it also updates
- `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
- and `best_valid_loss` in `params`.
-
- Args:
- params:
- The return value of :func:`get_params`.
- model:
- The training model.
- Returns:
- Return a dict containing previously saved training info.
- """
- if params.start_epoch > 1:
- filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
- else:
- return None
-
- logging.info(f"Resuming from file {filename}")
- assert filename.is_file(), f"{filename} does not exist!"
-
- saved_params = load_checkpoint(
- filename, model=model, model_avg=model_avg, strict=True
- )
-
- keys = [
- "best_train_epoch",
- "best_valid_epoch",
- "batch_idx_train",
- "best_train_loss",
- "best_valid_loss",
- ]
- for k in keys:
- params[k] = saved_params[k]
-
- return saved_params
-
-
-def compute_fbank_loss(
- params: AttributeDict,
- model: Union[nn.Module, DDP],
- features: Tensor,
- features_lens: Tensor,
- tokens: List[List[int]],
- is_training: bool,
-) -> Tuple[Tensor, MetricsTracker]:
- """
- Compute loss given the model and its inputs.
-
- Args:
- params:
- Parameters for training. See :func:`get_params`.
- model:
- The model for training.
- features:
- The target acoustic feature.
- features_lens:
- The number of frames of each utterance.
- tokens:
-        Input tokens representing the transcripts.
- is_training:
- True for training. False for validation. When it is True, this
- function enables autograd during computation; when it is False, it
- disables autograd.
- """
-
- device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-
- batch_size, num_frames, _ = features.shape
-
- features = torch.nn.functional.pad(
- features, (0, 0, 0, num_frames - features.size(1))
- ) # (B, T, F)
- noise = torch.randn_like(features) # (B, T, F)
-
- # Sampling t from uniform distribution
- if is_training:
- t = torch.rand(batch_size, 1, 1, device=device)
- else:
- t = (
- (torch.arange(batch_size, device=device) / batch_size)
- .unsqueeze(1)
- .unsqueeze(2)
- )
- with torch.set_grad_enabled(is_training):
-
- loss = model(
- tokens=tokens,
- features=features,
- features_lens=features_lens,
- noise=noise,
- t=t,
- condition_drop_ratio=params.condition_drop_ratio,
- )
-
- assert loss.requires_grad == is_training
- info = MetricsTracker()
- num_frames = features_lens.sum().item()
- info["frames"] = num_frames
- info["loss"] = loss.detach().cpu().item() * num_frames
-
- return loss, info
-
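-# For reference, a minimal sketch of the conditional flow-matching objective that the
-# model is assumed to compute internally (the actual loss lives in model.py):
-#   x_t    = (1 - t) * noise + t * features     # linear interpolation path
-#   target = features - noise                   # target velocity field
-#   loss   = MSE(v(x_t, t, tokens), target), averaged over the valid frames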
-
-def train_one_epoch(
- params: AttributeDict,
- model: Union[nn.Module, DDP],
- tokenizer: TokenizerEmilia,
- optimizer: Optimizer,
- scheduler: LRSchedulerType,
- train_dl: torch.utils.data.DataLoader,
- valid_dl: torch.utils.data.DataLoader,
- scaler: GradScaler,
- model_avg: Optional[nn.Module] = None,
- tb_writer: Optional[SummaryWriter] = None,
- world_size: int = 1,
- rank: int = 0,
-) -> None:
- """Train the model for one epoch.
-
- The training loss from the mean of all frames is saved in
- `params.train_loss`. It runs the validation process every
- `params.valid_interval` batches.
-
- Args:
- params:
- It is returned by :func:`get_params`.
- model:
- The model for training.
- tokenizer:
- Used to convert text to tokens.
- optimizer:
- The optimizer.
- scheduler:
-        The learning rate scheduler; we call its step_batch() and step_epoch() methods during training.
- train_dl:
- Dataloader for the training dataset.
- valid_dl:
- Dataloader for the validation dataset.
- scaler:
-        The scaler used for mixed precision training.
- tb_writer:
- Writer to write log messages to tensorboard.
- world_size:
- Number of nodes in DDP training. If it is 1, DDP is disabled.
- rank:
- The rank of the node in DDP training. If no DDP is used, it should
- be set to 0.
- """
- model.train()
- device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-
- # used to track the stats over iterations in one epoch
- tot_loss = MetricsTracker()
-
- saved_bad_model = False
-
- def save_bad_model(suffix: str = ""):
- save_checkpoint(
- filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
- model=model,
- model_avg=model_avg,
- params=params,
- optimizer=optimizer,
- scheduler=scheduler,
- sampler=train_dl.sampler,
- scaler=scaler,
- rank=0,
- )
-
- for batch_idx, batch in enumerate(train_dl):
-
- if batch_idx % 10 == 0:
- set_batch_count(model, get_adjusted_batch_count(params))
-
- if (
- params.valid_interval is None
- and batch_idx == 0
- and not params.print_diagnostics
- ) or (
- params.valid_interval is not None
- and params.batch_idx_train % params.valid_interval == 0
- and not params.print_diagnostics
- ):
- logging.info("Computing validation loss")
- valid_info = compute_validation_loss(
- params=params,
- model=model,
- tokenizer=tokenizer,
- valid_dl=valid_dl,
- world_size=world_size,
- )
- model.train()
- logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
- logging.info(
- f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
- )
- if tb_writer is not None:
- valid_info.write_summary(
- tb_writer, "train/valid_", params.batch_idx_train
- )
-
- params.batch_idx_train += 1
-
- batch_size = len(batch["text"])
-
- tokens, features, features_lens = prepare_input(
- params=params,
- batch=batch,
- device=device,
- tokenizer=tokenizer,
- return_tokens=True,
- return_feature=True,
- )
-
- try:
- with autocast("cuda", enabled=params.use_fp16):
- loss, loss_info = compute_fbank_loss(
- params=params,
- model=model,
- features=features,
- features_lens=features_lens,
- tokens=tokens,
- is_training=True,
- )
-
- tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
-
- scaler.scale(loss).backward()
-
- scheduler.step_batch(params.batch_idx_train)
- # Use the number of hours of speech to adjust the learning rate
- if params.lr_hours > 0:
- scheduler.step_epoch(
- params.batch_idx_train
- * params.max_duration
- * params.world_size
- / 3600
- )
- scaler.step(optimizer)
- scaler.update()
- optimizer.zero_grad()
- except RuntimeError as e:
- if "out of memory" in str(e):
- logging.info(f"out of memory error at rank {rank}")
- # optimizer.zero_grad()
- # duration_optimizer.zero_grad()
- torch.cuda.empty_cache()
-                raise
- else:
- logging.info(f"Caught exception : {e}.")
- save_bad_model()
- raise
- except Exception as e:
- logging.info(f"Caught exception : {e}.")
- save_bad_model()
- raise
-
- if params.print_diagnostics and batch_idx == 5:
- return
-
- if (
- rank == 0
- and params.batch_idx_train > 0
- and params.batch_idx_train % params.average_period == 0
- ):
- update_averaged_model(
- params=params,
- model_cur=model,
- model_avg=model_avg,
- )
-
- if (
- params.batch_idx_train > 0
- and params.batch_idx_train % params.save_every_n == 0
- ):
- save_checkpoint_with_global_batch_idx(
- out_dir=params.exp_dir,
- global_batch_idx=params.batch_idx_train,
- model=model,
- model_avg=model_avg,
- params=params,
- optimizer=optimizer,
- scheduler=scheduler,
- sampler=train_dl.sampler,
- scaler=scaler,
- rank=rank,
- )
- remove_checkpoints(
- out_dir=params.exp_dir,
- topk=params.keep_last_k,
- rank=rank,
- )
- if params.batch_idx_train % 100 == 0 and params.use_fp16:
- # If the grad scale was less than 1, try increasing it. The _growth_interval
- # of the grad scaler is configurable, but we can't configure it to have different
- # behavior depending on the current grad scale.
- cur_grad_scale = scaler._scale.item()
-
- if cur_grad_scale < 1024.0 or (
- cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
- ):
- scaler.update(cur_grad_scale * 2.0)
- if cur_grad_scale < 0.01:
- if not saved_bad_model:
- save_bad_model(suffix="-first-warning")
- saved_bad_model = True
- logging.warning(f"Grad scale is small: {cur_grad_scale}")
- if cur_grad_scale < 1.0e-05:
- save_bad_model()
- raise RuntimeError(
- f"grad_scale is too small, exiting: {cur_grad_scale}"
- )
-
- if params.batch_idx_train % params.log_interval == 0:
- cur_lr = max(scheduler.get_last_lr())
- cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
-
- logging.info(
- f"Epoch {params.cur_epoch}, batch {batch_idx}, "
- f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
- f"loss[{loss_info}], tot_loss[{tot_loss}], "
- f"cur_lr: {cur_lr:.2e}, "
- + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
- )
-
- if tb_writer is not None:
- tb_writer.add_scalar(
- "train/learning_rate", cur_lr, params.batch_idx_train
- )
- loss_info.write_summary(
- tb_writer, "train/current_", params.batch_idx_train
- )
- tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
- if params.use_fp16:
- tb_writer.add_scalar(
- "train/grad_scale", cur_grad_scale, params.batch_idx_train
- )
-
- loss_value = tot_loss["loss"]
- params.train_loss = loss_value
- if params.train_loss < params.best_train_loss:
- params.best_train_epoch = params.cur_epoch
- params.best_train_loss = params.train_loss
-
-
-def compute_validation_loss(
- params: AttributeDict,
- model: Union[nn.Module, DDP],
- tokenizer: TokenizerEmilia,
- valid_dl: torch.utils.data.DataLoader,
- world_size: int = 1,
-) -> MetricsTracker:
- """Run the validation process."""
-
- model.eval()
- device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-
-    # used to summarize the stats over iterations
- tot_loss = MetricsTracker()
-
- for batch_idx, batch in enumerate(valid_dl):
- tokens, features, features_lens = prepare_input(
- params=params,
- batch=batch,
- device=device,
- tokenizer=tokenizer,
- return_tokens=True,
- return_feature=True,
- )
-
- loss, loss_info = compute_fbank_loss(
- params=params,
- model=model,
- features=features,
- features_lens=features_lens,
- tokens=tokens,
- is_training=False,
- )
- assert loss.requires_grad is False
- tot_loss = tot_loss + loss_info
-
- if world_size > 1:
- tot_loss.reduce(loss.device)
-
- loss_value = tot_loss["loss"]
- if loss_value < params.best_valid_loss:
- params.best_valid_epoch = params.cur_epoch
- params.best_valid_loss = loss_value
-
- return tot_loss
-
-
-def run(rank, world_size, args):
- """
- Args:
- rank:
- It is a value between 0 and `world_size-1`, which is
- passed automatically by `mp.spawn()` in :func:`main`.
- The node with rank 0 is responsible for saving checkpoint.
- world_size:
- Number of GPUs for DDP training.
- args:
- The return value of get_parser().parse_args()
- """
- params = get_params()
- params.update(vars(args))
-
- fix_random_seed(params.seed)
- if world_size > 1:
- setup_dist(rank, world_size, params.master_port)
-
- setup_logger(f"{params.exp_dir}/log/log-train")
- os.makedirs(f"{params.exp_dir}/fbank", exist_ok=True)
-
- logging.info("Training started")
-
- if args.tensorboard and rank == 0:
- tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
- else:
- tb_writer = None
-
- device = torch.device("cpu")
- if torch.cuda.is_available():
- device = torch.device("cuda", rank)
- logging.info(f"Device: {device}")
-
- if params.dataset == "emilia":
- tokenizer = TokenizerEmilia(
- token_file=params.token_file, token_type=params.token_type
- )
- elif params.dataset == "libritts":
- tokenizer = TokenizerLibriTTS(
- token_file=params.token_file, token_type=params.token_type
- )
- params.vocab_size = tokenizer.vocab_size
- params.pad_id = tokenizer.pad_id
-
- params.device = device
-
- logging.info(params)
-
- logging.info("About to create model")
-
- model = get_model(params)
- if params.checkpoint is not None:
- logging.info(f"Loading pre-trained model from {params.checkpoint}")
- _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
- num_param = sum([p.numel() for p in model.parameters()])
- logging.info(f"Number of parameters : {num_param}")
-
- model_avg: Optional[nn.Module] = None
- if rank == 0:
- # model_avg is only used with rank 0
- model_avg = copy.deepcopy(model).to(torch.float64)
- assert params.start_epoch > 0, params.start_epoch
- if params.start_epoch > 1:
- checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
-
- model = model.to(device)
- if world_size > 1:
- logging.info("Using DDP")
- model = DDP(model, device_ids=[rank], find_unused_parameters=True)
-
- optimizer = ScaledAdam(
- get_parameter_groups_with_lrs(
- model,
- lr=params.base_lr,
- include_names=True,
- ),
- lr=params.base_lr, # should have no effect
- clipping_scale=2.0,
- )
-
- assert params.lr_hours >= 0
- if params.lr_hours > 0:
- scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
- else:
- scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
-
- scaler = GradScaler("cuda", enabled=params.use_fp16)
-
- if params.start_epoch > 1 and checkpoints is not None:
- # load state_dict for optimizers
- if "optimizer" in checkpoints:
- logging.info("Loading optimizer state dict")
- optimizer.load_state_dict(checkpoints["optimizer"])
-
- # load state_dict for schedulers
- if "scheduler" in checkpoints:
- logging.info("Loading scheduler state dict")
- scheduler.load_state_dict(checkpoints["scheduler"])
-
- if "grad_scaler" in checkpoints:
- logging.info("Loading grad scaler state dict")
- scaler.load_state_dict(checkpoints["grad_scaler"])
-
- if params.print_diagnostics:
- opts = diagnostics.TensorDiagnosticOptions(
- 512
- ) # allow 4 megabytes per sub-module
- diagnostic = diagnostics.attach_diagnostics(model, opts)
-
- if params.inf_check:
- register_inf_check_hooks(model)
-
- def remove_short_and_long_utt_emilia(c: Cut):
- if c.duration < 1.0 or c.duration > 30.0:
- return False
- return True
-
- def remove_short_and_long_utt_libritts(c: Cut):
- if c.duration < 1.0 or c.duration > 20.0:
- return False
- return True
-
- datamodule = TtsDataModule(args)
- if params.dataset == "emilia":
- train_cuts = CutSet.mux(
- datamodule.train_emilia_EN_cuts(),
- datamodule.train_emilia_ZH_cuts(),
- weights=[46000, 49000],
- )
- train_cuts = train_cuts.filter(remove_short_and_long_utt_emilia)
- dev_cuts = CutSet.mux(
- datamodule.dev_emilia_EN_cuts(),
- datamodule.dev_emilia_ZH_cuts(),
- weights=[0.5, 0.5],
- )
- elif params.dataset == "libritts":
- train_cuts = datamodule.train_libritts_cuts()
- train_cuts = train_cuts.filter(remove_short_and_long_utt_libritts)
- dev_cuts = datamodule.dev_libritts_cuts()
-
- train_dl = datamodule.train_dataloaders(train_cuts)
-
- valid_dl = datamodule.dev_dataloaders(dev_cuts)
-
- for epoch in range(params.start_epoch, params.num_epochs + 1):
- logging.info(f"Start epoch {epoch}")
-
- if params.lr_hours == 0:
- scheduler.step_epoch(epoch - 1)
- fix_random_seed(params.seed + epoch - 1)
- train_dl.sampler.set_epoch(epoch - 1)
-
- params.cur_epoch = epoch
-
- if tb_writer is not None:
- tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
-
- train_one_epoch(
- params=params,
- model=model,
- model_avg=model_avg,
- tokenizer=tokenizer,
- optimizer=optimizer,
- scheduler=scheduler,
- train_dl=train_dl,
- valid_dl=valid_dl,
- scaler=scaler,
- tb_writer=tb_writer,
- world_size=world_size,
- rank=rank,
- )
-
- if params.print_diagnostics:
- diagnostic.print_diagnostics()
- break
-
- filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
- save_checkpoint(
- filename=filename,
- params=params,
- model=model,
- model_avg=model_avg,
- optimizer=optimizer,
- scheduler=scheduler,
- sampler=train_dl.sampler,
- scaler=scaler,
- rank=rank,
- )
-
- if rank == 0:
- if params.best_train_epoch == params.cur_epoch:
- best_train_filename = params.exp_dir / "best-train-loss.pt"
- copyfile(src=filename, dst=best_train_filename)
-
- if params.best_valid_epoch == params.cur_epoch:
- best_valid_filename = params.exp_dir / "best-valid-loss.pt"
- copyfile(src=filename, dst=best_valid_filename)
-
- logging.info("Done!")
-
- if world_size > 1:
- torch.distributed.barrier()
- cleanup_dist()
-
-
-def main():
- parser = get_parser()
- TtsDataModule.add_arguments(parser)
- args = parser.parse_args()
- args.exp_dir = Path(args.exp_dir)
-
- world_size = args.world_size
- assert world_size >= 1
- if world_size > 1:
- mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
- else:
- run(rank=0, world_size=1, args=args)
-
-
-if __name__ == "__main__":
- torch.set_num_threads(1)
- torch.set_num_interop_threads(1)
- main()
diff --git a/egs/zipvoice/zipvoice/tts_datamodule.py b/egs/zipvoice/zipvoice/tts_datamodule.py
deleted file mode 100644
index 972c700f7..000000000
--- a/egs/zipvoice/zipvoice/tts_datamodule.py
+++ /dev/null
@@ -1,456 +0,0 @@
-# Copyright 2021 Piotr Żelasko
-# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
-# Zengwei Yao,
-# Zengrui Jin,
-# Han Zhu,
-# Wei Kang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import argparse
-import logging
-from functools import lru_cache
-from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
-
-import torch
-from feature import TorchAudioFbank, TorchAudioFbankConfig
-from lhotse import CutSet, load_manifest_lazy, validate
-from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
- DynamicBucketingSampler,
- PrecomputedFeatures,
- SimpleCutSampler,
-)
-from lhotse.dataset.collation import collate_audio
-from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
- BatchIO,
- OnTheFlyFeatures,
-)
-from lhotse.utils import fix_random_seed, ifnone
-from torch.utils.data import DataLoader
-
-from icefall.utils import str2bool
-
-
-class _SeedWorkers:
- def __init__(self, seed: int):
- self.seed = seed
-
- def __call__(self, worker_id: int):
- fix_random_seed(self.seed + worker_id)
-
-
-SAMPLING_RATE = 24000
-
-
-class TtsDataModule:
- """
- DataModule for tts experiments.
- It assumes there is always one train and valid dataloader,
- but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
- and test-other).
-
-    It contains all the common data pipeline modules used in TTS
- experiments, e.g.:
- - dynamic batch size,
- - bucketing samplers,
- - cut concatenation,
- - on-the-fly feature extraction
-
-    This class should be derived for specific corpora used in TTS tasks.
- """
-
- def __init__(self, args: argparse.Namespace):
- self.args = args
-
- @classmethod
- def add_arguments(cls, parser: argparse.ArgumentParser):
- group = parser.add_argument_group(
- title="TTS data related options",
- description="These options are used for the preparation of "
- "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
- "effective batch sizes, sampling strategies, applied data "
- "augmentations, etc.",
- )
- group.add_argument(
- "--manifest-dir",
- type=Path,
- default=Path("data/fbank_emilia"),
- help="Path to directory with train/valid/test cuts.",
- )
- group.add_argument(
- "--max-duration",
- type=int,
- default=200.0,
- help="Maximum pooled recordings duration (seconds) in a "
- "single batch. You can reduce it if it causes CUDA OOM.",
- )
- group.add_argument(
- "--bucketing-sampler",
- type=str2bool,
- default=True,
- help="When enabled, the batches will come from buckets of "
- "similar duration (saves padding frames).",
- )
- group.add_argument(
- "--num-buckets",
- type=int,
- default=100,
-            help="The number of buckets for the DynamicBucketingSampler "
-            "(you might want to increase it for larger datasets).",
- )
-
- group.add_argument(
- "--on-the-fly-feats",
- type=str2bool,
- default=False,
- help="When enabled, use on-the-fly cut mixing and feature "
- "extraction. Will drop existing precomputed feature manifests "
- "if available.",
- )
- group.add_argument(
- "--shuffle",
- type=str2bool,
- default=True,
- help="When enabled (=default), the examples will be "
- "shuffled for each epoch.",
- )
- group.add_argument(
- "--drop-last",
- type=str2bool,
- default=True,
- help="Whether to drop last batch. Used by sampler.",
- )
- group.add_argument(
- "--return-cuts",
- type=str2bool,
- default=True,
- help="When enabled, each batch will have the "
- "field: batch['cut'] with the cuts that "
- "were used to construct it.",
- )
- group.add_argument(
- "--num-workers",
- type=int,
- default=8,
- help="The number of training dataloader workers that "
- "collect the batches.",
- )
-
- group.add_argument(
- "--input-strategy",
- type=str,
- default="PrecomputedFeatures",
- help="AudioSamples or PrecomputedFeatures",
- )
-
- def train_dataloaders(
- self,
- cuts_train: CutSet,
- sampler_state_dict: Optional[Dict[str, Any]] = None,
- ) -> DataLoader:
- """
- Args:
- cuts_train:
- CutSet for training.
- sampler_state_dict:
- The state dict for the training sampler.
- """
- logging.info("About to create train dataset")
- train = SpeechSynthesisDataset(
- return_text=True,
- return_tokens=True,
- return_spk_ids=True,
- feature_input_strategy=eval(self.args.input_strategy)(),
- return_cuts=self.args.return_cuts,
- )
-
- if self.args.on_the_fly_feats:
- sampling_rate = SAMPLING_RATE
- config = TorchAudioFbankConfig(
- sampling_rate=sampling_rate,
- n_mels=100,
- n_fft=1024,
- hop_length=256,
- )
- train = SpeechSynthesisDataset(
- return_text=True,
- return_tokens=True,
- return_spk_ids=True,
- feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
- return_cuts=self.args.return_cuts,
- )
-
- if self.args.bucketing_sampler:
- logging.info("Using DynamicBucketingSampler.")
- train_sampler = DynamicBucketingSampler(
- cuts_train,
- max_duration=self.args.max_duration,
- shuffle=self.args.shuffle,
- num_buckets=self.args.num_buckets,
- buffer_size=self.args.num_buckets * 2000,
- shuffle_buffer_size=self.args.num_buckets * 5000,
- drop_last=self.args.drop_last,
- )
- else:
- logging.info("Using SimpleCutSampler.")
- train_sampler = SimpleCutSampler(
- cuts_train,
- max_duration=self.args.max_duration,
- shuffle=self.args.shuffle,
- )
- logging.info("About to create train dataloader")
-
- if sampler_state_dict is not None:
- logging.info("Loading sampler state dict")
- train_sampler.load_state_dict(sampler_state_dict)
-
- # 'seed' is derived from the current random state, which will have
- # previously been set in the main process.
- seed = torch.randint(0, 100000, ()).item()
- worker_init_fn = _SeedWorkers(seed)
-
- train_dl = DataLoader(
- train,
- sampler=train_sampler,
- batch_size=None,
- num_workers=self.args.num_workers,
- persistent_workers=False,
- worker_init_fn=worker_init_fn,
- )
-
- return train_dl
-
- def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
- logging.info("About to create dev dataset")
- if self.args.on_the_fly_feats:
- sampling_rate = SAMPLING_RATE
- config = TorchAudioFbankConfig(
- sampling_rate=sampling_rate,
- n_mels=100,
- n_fft=1024,
- hop_length=256,
- )
- validate = SpeechSynthesisDataset(
- return_text=True,
- return_tokens=True,
- return_spk_ids=True,
- feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
- return_cuts=self.args.return_cuts,
- )
- else:
- validate = SpeechSynthesisDataset(
- return_text=True,
- return_tokens=True,
- return_spk_ids=True,
- feature_input_strategy=eval(self.args.input_strategy)(),
- return_cuts=self.args.return_cuts,
- )
- dev_sampler = DynamicBucketingSampler(
- cuts_valid,
- max_duration=self.args.max_duration,
- shuffle=False,
- )
- logging.info("About to create valid dataloader")
- dev_dl = DataLoader(
- validate,
- sampler=dev_sampler,
- batch_size=None,
- num_workers=2,
- persistent_workers=False,
- )
-
- return dev_dl
-
- def test_dataloaders(self, cuts: CutSet) -> DataLoader:
- logging.info("About to create test dataset")
- if self.args.on_the_fly_feats:
- sampling_rate = SAMPLING_RATE
- config = TorchAudioFbankConfig(
- sampling_rate=sampling_rate,
- n_mels=100,
- n_fft=1024,
- hop_length=256,
- )
- test = SpeechSynthesisDataset(
- return_text=True,
- return_tokens=True,
- return_spk_ids=True,
- feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
- return_cuts=self.args.return_cuts,
- return_audio=True,
- )
- else:
- test = SpeechSynthesisDataset(
- return_text=True,
- return_tokens=True,
- return_spk_ids=True,
- feature_input_strategy=eval(self.args.input_strategy)(),
- return_cuts=self.args.return_cuts,
- return_audio=True,
- )
- test_sampler = DynamicBucketingSampler(
- cuts,
- max_duration=self.args.max_duration,
- shuffle=False,
- )
- logging.info("About to create test dataloader")
- test_dl = DataLoader(
- test,
- batch_size=None,
- sampler=test_sampler,
- num_workers=self.args.num_workers,
- )
- return test_dl
-
- @lru_cache()
- def train_emilia_EN_cuts(self) -> CutSet:
-        logging.info("About to get the train EN subset")
- return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz")
-
- @lru_cache()
- def train_emilia_ZH_cuts(self) -> CutSet:
-        logging.info("About to get the train ZH subset")
- return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz")
-
- @lru_cache()
- def dev_emilia_EN_cuts(self) -> CutSet:
-        logging.info("About to get the dev EN subset")
- return load_manifest_lazy(
- self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz"
- )
-
- @lru_cache()
- def dev_emilia_ZH_cuts(self) -> CutSet:
-        logging.info("About to get the dev ZH subset")
- return load_manifest_lazy(
- self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz"
- )
-
- @lru_cache()
- def train_libritts_cuts(self) -> CutSet:
- logging.info(
-            "About to get the shuffled train-clean-100, "
-            "train-clean-360 and train-other-500 cuts"
- )
- return load_manifest_lazy(
- self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
- )
-
- @lru_cache()
- def dev_libritts_cuts(self) -> CutSet:
- logging.info("About to get dev-clean cuts")
- return load_manifest_lazy(
- self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
- )
-
-
-class SpeechSynthesisDataset(torch.utils.data.Dataset):
- """
- The PyTorch Dataset for the speech synthesis task.
- Each item in this dataset is a dict of:
-
- .. code-block::
-
- {
-            'audio': (B x NumSamples) float tensor  # when return_audio=True
-            'features': (B x NumFrames x NumFeatures) float tensor
-            'audio_lens': (B, ) int tensor  # when return_audio=True
-            'features_lens': (B, ) int tensor
- 'text': List[str] of len B # when return_text=True
- 'tokens': List[List[str]] # when return_tokens=True
- 'speakers': List[str] of len B # when return_spk_ids=True
- 'cut': List of Cuts # when return_cuts=True
- }
- """
-
- def __init__(
- self,
- cut_transforms: List[Callable[[CutSet], CutSet]] = None,
- feature_input_strategy: BatchIO = PrecomputedFeatures(),
- feature_transforms: Union[Sequence[Callable], Callable] = None,
- return_text: bool = True,
- return_tokens: bool = False,
- return_spk_ids: bool = False,
- return_cuts: bool = False,
- return_audio: bool = False,
- ) -> None:
- super().__init__()
-
- self.cut_transforms = ifnone(cut_transforms, [])
- self.feature_input_strategy = feature_input_strategy
-
- self.return_text = return_text
- self.return_tokens = return_tokens
- self.return_spk_ids = return_spk_ids
- self.return_cuts = return_cuts
- self.return_audio = return_audio
-
- if feature_transforms is None:
- feature_transforms = []
- elif not isinstance(feature_transforms, Sequence):
- feature_transforms = [feature_transforms]
-
- assert all(
- isinstance(transform, Callable) for transform in feature_transforms
- ), "Feature transforms must be Callable"
- self.feature_transforms = feature_transforms
-
- def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
- validate_for_tts(cuts)
-
- for transform in self.cut_transforms:
- cuts = transform(cuts)
-
- features, features_lens = self.feature_input_strategy(cuts)
-
- for transform in self.feature_transforms:
- features = transform(features)
-
- batch = {
- "features": features,
- "features_lens": features_lens,
- }
-
- if self.return_audio:
- audio, audio_lens = collate_audio(cuts)
- batch["audio"] = audio
- batch["audio_lens"] = audio_lens
-
- if self.return_text:
- # use normalized text
- text = [cut.supervisions[0].normalized_text for cut in cuts]
- batch["text"] = text
-
- if self.return_tokens:
- tokens = [cut.tokens for cut in cuts]
- batch["tokens"] = tokens
-
- if self.return_spk_ids:
- batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
-
- if self.return_cuts:
- batch["cut"] = [cut for cut in cuts]
-
- return batch
-
-
-def validate_for_tts(cuts: CutSet) -> None:
- validate(cuts)
- for cut in cuts:
- assert (
- len(cut.supervisions) == 1
- ), "Only the Cuts with single supervision are supported."
diff --git a/egs/zipvoice/zipvoice/utils.py b/egs/zipvoice/zipvoice/utils.py
deleted file mode 100644
index 4092d0ae4..000000000
--- a/egs/zipvoice/zipvoice/utils.py
+++ /dev/null
@@ -1,219 +0,0 @@
-from typing import Any, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-
-class AttributeDict(dict):
- def __getattr__(self, key):
- if key in self:
- return self[key]
- raise AttributeError(f"No such attribute '{key}'")
-
- def __setattr__(self, key, value):
- self[key] = value
-
- def __delattr__(self, key):
- if key in self:
- del self[key]
- return
- raise AttributeError(f"No such attribute '{key}'")
-
-
-def prepare_input(
- params: AttributeDict,
- batch: dict,
- device: torch.device,
- tokenizer: Optional[Any] = None,
- return_tokens: bool = False,
- return_feature: bool = False,
- return_audio: bool = False,
- return_prompt: bool = False,
-):
- """
- Parse the features and targets of the current batch.
- Args:
- params:
- It is returned by :func:`get_params`.
- batch:
- It is the return value from iterating
- `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
- for the format of the `batch`.
-      tokenizer:
-        Used to convert text or tokens to token ids.
- device:
- The device of Tensor.
- """
- return_list = []
-
- if return_tokens:
- assert tokenizer is not None
-
- if params.token_type == "phone":
- tokens = tokenizer.tokens_to_token_ids(batch["tokens"])
- else:
- tokens = tokenizer.texts_to_token_ids(batch["text"])
- return_list += [tokens]
-
- if return_feature:
- features = batch["features"].to(device)
- features_lens = batch["features_lens"].to(device)
- return_list += [features * params.feat_scale, features_lens]
-
- if return_audio:
- return_list += [batch["audio"], batch["audio_lens"]]
-
- if return_prompt:
- if return_tokens:
- if params.token_type == "phone":
- prompt_tokens = tokenizer.tokens_to_token_ids(batch["prompt"]["tokens"])
- else:
- prompt_tokens = tokenizer.texts_to_token_ids(batch["prompt"]["text"])
- return_list += [prompt_tokens]
- if return_feature:
- prompt_features = batch["prompt"]["features"].to(device)
- prompt_features_lens = batch["prompt"]["features_lens"].to(device)
- return_list += [prompt_features * params.feat_scale, prompt_features_lens]
- if return_audio:
- return_list += [batch["prompt"]["audio"], batch["prompt"]["audio_lens"]]
-
- return return_list
-
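-# Example usage (illustrative), mirroring the call in train_flow.py:
-#   tokens, features, features_lens = prepare_input(
-#       params=params, batch=batch, device=device, tokenizer=tokenizer,
-#       return_tokens=True, return_feature=True,
-#   )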
-
-def prepare_avg_tokens_durations(features_lens, tokens_lens):
- tokens_durations = []
- for i in range(len(features_lens)):
- utt_duration = features_lens[i]
- avg_token_duration = utt_duration // tokens_lens[i]
- tokens_durations.append([avg_token_duration] * tokens_lens[i])
- return tokens_durations
-
-
-def pad_labels(y: List[List[int]], pad_id: int, device: torch.device):
- """
-    Append one `pad_id` to each transcript and pad them to the same length with `pad_id`.
-
-    Args:
-      y: the transcripts, which is a list of lists of token ids
-
- Returns:
- Return a Tensor of padded transcripts.
- """
- y = [l + [pad_id] for l in y]
- length = max([len(l) for l in y])
- y = [l + [pad_id] * (length - len(l)) for l in y]
- return torch.tensor(y, dtype=torch.int64, device=device)
-
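-# Example (illustrative): pad_labels([[5, 6], [7]], pad_id=0, device=torch.device("cpu"))
-# appends one pad_id to every transcript and then pads to equal length, giving
-# tensor([[5, 6, 0], [7, 0, 0]]).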
-
-def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor:
- """
- Gets position in the transcript for each frame, i.e. the position
- in the symbol-sequence to look up.
-
- Args:
- durations:
- Duration of each token in transcripts.
- num_frames:
- The maximum frame length of the current batch.
-
- Returns:
- Return a Tensor of shape (batch_size, num_frames)
- """
- durations = [x + [num_frames - sum(x)] for x in durations]
- batch_size = len(durations)
- ans = torch.zeros(batch_size, num_frames, dtype=torch.int64)
- for b in range(batch_size):
- this_dur = durations[b]
- cur_frame = 0
- for i, d in enumerate(this_dur):
- ans[b, cur_frame : cur_frame + d] = i
- cur_frame += d
- assert cur_frame == num_frames, (cur_frame, num_frames)
- return ans
-
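-# Example (illustrative): get_tokens_index([[2, 3]], num_frames=7) returns
-# tensor([[0, 0, 1, 1, 1, 2, 2]]); the trailing frames map to an appended
-# "remainder" entry that covers num_frames - sum(durations).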
-
-def to_int_tuple(s: str):
- return tuple(map(int, s.split(",")))
-
-
-def get_adjusted_batch_count(params: AttributeDict) -> float:
- # returns the number of batches we would have used so far if we had used the reference
- # duration. This is for purposes of set_batch_count().
- return (
- params.batch_idx_train
- * (params.max_duration * params.world_size)
- / params.ref_duration
- )
-
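-# Example (illustrative): with batch_idx_train=1000, max_duration=500, world_size=8
-# and ref_duration=50, the adjusted batch count is 1000 * (500 * 8) / 50 = 80000.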
-
-def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
- if isinstance(model, DDP):
- # get underlying nn.Module
- model = model.module
- for name, module in model.named_modules():
- if hasattr(module, "batch_count"):
- module.batch_count = batch_count
- if hasattr(module, "name"):
- module.name = name
-
-
-def condition_time_mask(
- features_lens: torch.Tensor, mask_percent: Tuple[float, float], max_len: int = 0
-) -> torch.Tensor:
- """
- Apply Time masking.
- Args:
- features_lens:
- input tensor of shape ``(B)``
-        mask_percent:
-            the (min, max) fraction of each utterance to be masked.
-        max_len:
-            the minimum length of the returned mask; the actual length is
-            max(max_len, features_lens.max()).
- Returns:
- Return a 2-D bool tensor (B, T), where masked positions
- are filled with `True` and non-masked positions are
- filled with `False`.
- """
- mask_size = (
- torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
- * features_lens
- ).to(torch.int64)
- mask_starts = (
- torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
- ).to(torch.int64)
- mask_ends = mask_starts + mask_size
- max_len = max(max_len, features_lens.max())
- seq_range = torch.arange(0, max_len, device=features_lens.device)
- mask = (seq_range[None, :] >= mask_starts[:, None]) & (
- seq_range[None, :] < mask_ends[:, None]
- )
- return mask
-
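-# Example (illustrative): mask between 30% and 70% of each utterance.
-#   lens = torch.tensor([80, 100])
-#   mask = condition_time_mask(lens, mask_percent=(0.3, 0.7))  # -> (2, 100) bool tensor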
-
-def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
- """
- Args:
- lengths:
- A 1-D tensor containing sentence lengths.
- max_len:
-        The minimum length of the masks; the actual length is max(max_len, lengths.max()).
- Returns:
- Return a 2-D bool tensor, where masked positions
- are filled with `True` and non-masked positions are
- filled with `False`.
-
- >>> lengths = torch.tensor([1, 3, 2, 5])
- >>> make_pad_mask(lengths)
- tensor([[False, True, True, True, True],
- [False, False, False, True, True],
- [False, False, True, True, True],
- [False, False, False, False, False]])
- """
- assert lengths.ndim == 1, lengths.ndim
- max_len = max(max_len, lengths.max())
- n = lengths.size(0)
- seq_range = torch.arange(0, max_len, device=lengths.device)
-    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
-
-    return expanded_lengths >= lengths.unsqueeze(-1)
diff --git a/egs/zipvoice/zipvoice/zipformer.py b/egs/zipvoice/zipvoice/zipformer.py
deleted file mode 100644
index 190191cbb..000000000
--- a/egs/zipvoice/zipvoice/zipformer.py
+++ /dev/null
@@ -1,1648 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey,
-# Zengwei Yao,
-# Wei Kang
-# Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import copy
-import logging
-import math
-import random
-from typing import Optional, Tuple, Union
-
-import torch
-from scaling import (
- Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
-)
-from scaling import (
- ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
-)
-from scaling import (
- ActivationDropoutAndLinear,
- Balancer,
- BiasNorm,
- Dropout2,
- FloatLike,
- ScheduledFloat,
- SwooshR,
- Whiten,
- limit_param_value,
- penalize_abs_values_gt,
- softmax,
-)
-from torch import Tensor, nn
-
-
-def timestep_embedding(timesteps, dim, max_period=10000):
- """Create sinusoidal timestep embeddings.
-
-    :param timesteps: a tensor of shape (N,) or (N, T)
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
-    :return: a Tensor of positional embeddings, of shape (N, dim) or (T, N, dim)
- """
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period)
- * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
- / half
- )
-
- if timesteps.dim() == 2:
- timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
-
- args = timesteps[..., None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
- return embedding
-
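-# Example (illustrative): embed a batch of 4 scalar timesteps into 192-dim vectors.
-#   t = torch.rand(4)                         # (N,)
-#   emb = timestep_embedding(t, dim=192)      # -> shape (4, 192)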
-
-class TTSZipformer(nn.Module):
- """
- Args:
-
- Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
- as downsampling_factor if they are single ints or one-element tuples. The length of
- downsampling_factor defines the number of stacks.
-
-    downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
-    encoder_dim (int): embedding dimension used in all encoder stacks.
-    num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
-    query_head_dim (int or Tuple[int]): dimension of query and key per attention
-       head: per stack, if a tuple.
- pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
- attention head
- value_head_dim (int or Tuple[int]): dimension of value in each attention head
- num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
- Must be at least 4.
- feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
-    cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module
-
- pos_dim (int): the dimension of each positional-encoding vector prior to projection,
- e.g. 128.
-
- dropout (float): dropout rate
- warmup_batches (float): number of batches to warm up over; this controls
- dropout of encoder layers.
-    use_time_embed: (bool): if True, take the time embedding as an additional input.
- time_embed_dim: (int): the dimension of the time embedding.
- """
-
- def __init__(
- self,
- in_dim: int,
- out_dim: int,
- downsampling_factor: Tuple[int] = (2, 4),
- num_encoder_layers: Union[int, Tuple[int]] = 4,
- cnn_module_kernel: Union[int, Tuple[int]] = 31,
- encoder_dim: int = 384,
- query_head_dim: int = 24,
- pos_head_dim: int = 4,
- value_head_dim: int = 12,
- num_heads: int = 8,
- feedforward_dim: int = 1536,
- pos_dim: int = 192,
- dropout: FloatLike = None, # see code below for default
- warmup_batches: float = 4000.0,
- use_time_embed: bool = True,
- time_embed_dim: int = 192,
- use_guidance_scale_embed: bool = False,
- guidance_scale_embed_dim: int = 192,
- use_conv: bool = True,
- ) -> None:
- super(TTSZipformer, self).__init__()
-
- if dropout is None:
- dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
-
- def _to_tuple(x):
- """Converts a single int or a 1-tuple of an int to a tuple with the same length
- as downsampling_factor"""
- if isinstance(x, int):
- x = (x,)
- if len(x) == 1:
- x = x * len(downsampling_factor)
- else:
- assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
- return x
-
- def _assert_downsampling_factor(factors):
- """assert downsampling_factor follows u-net style"""
- assert factors[0] == 1 and factors[-1] == 1
-
- for i in range(1, len(factors) // 2 + 1):
- assert factors[i] == factors[i - 1] * 2
-
- for i in range(len(factors) // 2 + 1, len(factors)):
- assert factors[i] * 2 == factors[i - 1]
-
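-        # Example (illustrative): (1, 2, 4, 2, 1) satisfies the U-Net constraint,
-        # since the factors double towards the middle stack and halve back to 1.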
- _assert_downsampling_factor(downsampling_factor)
- self.downsampling_factor = downsampling_factor # tuple
- num_encoder_layers = _to_tuple(num_encoder_layers)
- self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
- self.encoder_dim = encoder_dim
- self.num_encoder_layers = num_encoder_layers
- self.query_head_dim = query_head_dim
- self.value_head_dim = value_head_dim
- self.num_heads = num_heads
-
- self.use_time_embed = use_time_embed
- self.use_guidance_scale_embed = use_guidance_scale_embed
-
- self.time_embed_dim = time_embed_dim
- if self.use_time_embed:
- assert time_embed_dim != -1
- else:
- time_embed_dim = -1
- self.guidance_scale_embed_dim = guidance_scale_embed_dim
-
- self.in_proj = nn.Linear(in_dim, encoder_dim)
- self.out_proj = nn.Linear(encoder_dim, out_dim)
-
- # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
- encoders = []
-
- num_encoders = len(downsampling_factor)
- for i in range(num_encoders):
- encoder_layer = Zipformer2EncoderLayer(
- embed_dim=encoder_dim,
- pos_dim=pos_dim,
- num_heads=num_heads,
- query_head_dim=query_head_dim,
- pos_head_dim=pos_head_dim,
- value_head_dim=value_head_dim,
- feedforward_dim=feedforward_dim,
- use_conv=use_conv,
- cnn_module_kernel=cnn_module_kernel[i],
- dropout=dropout,
- )
-
-            # Stagger the warm-up schedule of the encoder stacks so that earlier
-            # stacks finish warming up before later ones start.
- encoder = Zipformer2Encoder(
- encoder_layer,
- num_encoder_layers[i],
- embed_dim=encoder_dim,
- time_embed_dim=time_embed_dim,
- pos_dim=pos_dim,
- warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
- warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
- final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
- )
-
- if downsampling_factor[i] != 1:
- encoder = DownsampledZipformer2Encoder(
- encoder,
- dim=encoder_dim,
- downsample=downsampling_factor[i],
- )
-
- encoders.append(encoder)
-
- self.encoders = nn.ModuleList(encoders)
- if self.use_time_embed:
- self.time_embed = nn.Sequential(
- nn.Linear(time_embed_dim, time_embed_dim * 2),
- SwooshR(),
- nn.Linear(time_embed_dim * 2, time_embed_dim),
- )
- else:
- self.time_embed = None
-
- if self.use_guidance_scale_embed:
- self.guidance_scale_embed = ScaledLinear(
- guidance_scale_embed_dim, time_embed_dim, bias=False, initial_scale=0.1
- )
- else:
- self.guidance_scale_embed = None
-
- def forward(
- self,
- x: Tensor,
- t: Optional[Tensor] = None,
- padding_mask: Optional[Tensor] = None,
- guidance_scale: Optional[Tensor] = None,
-    ) -> Tensor:
- """
- Args:
- x:
- The input tensor. Its shape is (batch_size, seq_len, feature_dim).
- t:
-                A timestep tensor of shape (batch_size,) or (batch_size, seq_len).
- padding_mask:
- The mask for padding, of shape (batch_size, seq_len); True means
- masked position. May be None.
- Returns:
-            Return the output embeddings, of shape (batch_size, seq_len, out_dim).
- """
- x = x.permute(1, 0, 2)
- x = self.in_proj(x)
-
- if t is not None:
- assert t.dim() == 1 or t.dim() == 2, t.shape
- time_emb = timestep_embedding(t, self.time_embed_dim)
- if guidance_scale is not None:
- assert (
- guidance_scale.dim() == 1 or guidance_scale.dim() == 2
- ), guidance_scale.shape
- guidance_scale_emb = self.guidance_scale_embed(
- timestep_embedding(guidance_scale, self.guidance_scale_embed_dim)
- )
- time_emb = time_emb + guidance_scale_emb
- time_emb = self.time_embed(time_emb)
- else:
- time_emb = None
-
- attn_mask = None
-
- for i, module in enumerate(self.encoders):
- x = module(
- x,
- time_emb=time_emb,
- src_key_padding_mask=padding_mask,
- attn_mask=attn_mask,
- )
- x = self.out_proj(x)
- x = x.permute(1, 0, 2)
- return x
-
-
-def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
- return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
-
-
-class Zipformer2EncoderLayer(nn.Module):
- """
- Args:
- embed_dim: the number of expected features in the input (required).
-        num_heads: the number of heads in the multi-head attention modules (required).
- feedforward_dim: the dimension of the feedforward network model (required).
- dropout: the dropout value (default=0.1).
- cnn_module_kernel (int): Kernel size of convolution module (default=31).
-
- Examples::
- >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
- >>> src = torch.rand(10, 32, 512)
- >>> pos_emb = torch.rand(32, 19, 512)
- >>> out = encoder_layer(src, pos_emb)
- """
-
- def __init__(
- self,
- embed_dim: int,
- pos_dim: int,
- num_heads: int,
- query_head_dim: int,
- pos_head_dim: int,
- value_head_dim: int,
- feedforward_dim: int,
- dropout: FloatLike = 0.1,
- cnn_module_kernel: int = 31,
- use_conv: bool = True,
- attention_skip_rate: FloatLike = ScheduledFloat(
- (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
- ),
- conv_skip_rate: FloatLike = ScheduledFloat(
- (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
- ),
- const_attention_rate: FloatLike = ScheduledFloat(
- (0.0, 0.25), (4000.0, 0.025), default=0
- ),
- ff2_skip_rate: FloatLike = ScheduledFloat(
- (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
- ),
- ff3_skip_rate: FloatLike = ScheduledFloat(
- (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
- ),
- bypass_skip_rate: FloatLike = ScheduledFloat(
- (0.0, 0.5), (4000.0, 0.02), default=0
- ),
- ) -> None:
- super(Zipformer2EncoderLayer, self).__init__()
- self.embed_dim = embed_dim
-
- # self.bypass implements layer skipping as well as bypass; see its default values.
- self.bypass = BypassModule(
- embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
- )
- # bypass_mid is bypass used in the middle of the layer.
- self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
-
- # skip probability for dynamic modules (meaning: anything but feedforward).
- self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
- # an additional skip probability that applies to ConvModule to stop it from
- # contributing too much early on.
- self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
-
- # ff2_skip_rate is to prevent the ff2 module from having output that's too big
- # compared to its residual.
- self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
- self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
-
- self.const_attention_rate = copy.deepcopy(const_attention_rate)
-
- self.self_attn_weights = RelPositionMultiheadAttentionWeights(
- embed_dim,
- pos_dim=pos_dim,
- num_heads=num_heads,
- query_head_dim=query_head_dim,
- pos_head_dim=pos_head_dim,
- dropout=0.0,
- )
-
- self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
-
- self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
-
- self.feed_forward1 = FeedforwardModule(
- embed_dim, (feedforward_dim * 3) // 4, dropout
- )
-
- self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
-
- self.feed_forward3 = FeedforwardModule(
- embed_dim, (feedforward_dim * 5) // 4, dropout
- )
-
- self.nonlin_attention = NonlinAttention(
- embed_dim, hidden_channels=3 * embed_dim // 4
- )
-
- self.use_conv = use_conv
-
- if self.use_conv:
- self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel)
-
- self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel)
-
- self.norm = BiasNorm(embed_dim)
-
- self.balancer1 = Balancer(
- embed_dim,
- channel_dim=-1,
- min_positive=0.45,
- max_positive=0.55,
- min_abs=0.2,
- max_abs=4.0,
- )
-
- # balancer for output of NonlinAttentionModule
- self.balancer_na = Balancer(
- embed_dim,
- channel_dim=-1,
- min_positive=0.3,
- max_positive=0.7,
- min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
- prob=0.05, # out of concern for memory usage
- )
-
- # balancer for output of feedforward2; prevents its output from staying too
- # small. Give this a very small probability, even at the start of training:
- # it is there to fix a rare problem and it is OK to fix it slowly.
- self.balancer_ff2 = Balancer(
- embed_dim,
- channel_dim=-1,
- min_positive=0.3,
- max_positive=0.7,
- min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
- max_abs=2.0,
- prob=0.05,
- )
-
- self.balancer_ff3 = Balancer(
- embed_dim,
- channel_dim=-1,
- min_positive=0.3,
- max_positive=0.7,
- min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
- max_abs=4.0,
- prob=0.05,
- )
-
- self.whiten = Whiten(
- num_groups=1,
- whitening_limit=_whitening_schedule(4.0, ratio=3.0),
- prob=(0.025, 0.25),
- grad_scale=0.01,
- )
-
- self.balancer2 = Balancer(
- embed_dim,
- channel_dim=-1,
- min_positive=0.45,
- max_positive=0.55,
- min_abs=0.1,
- max_abs=4.0,
- )
-
- def get_sequence_dropout_mask(
- self, x: Tensor, dropout_rate: float
- ) -> Optional[Tensor]:
- if (
- dropout_rate == 0.0
- or not self.training
- or torch.jit.is_scripting()
- or torch.jit.is_tracing()
- ):
- return None
- batch_size = x.shape[1]
- mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
- return mask
-
- def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
- """
- Apply sequence-level dropout to x.
- x shape: (seq_len, batch_size, embed_dim)
- """
- dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
- if dropout_mask is None:
- return x
- else:
- return x * dropout_mask
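-
- # A minimal sketch (illustrative values only) of what the two helpers above do:
- # each sequence in the batch is either kept or dropped as a whole, rather than
- # dropping individual elements, e.g.:
- #   mask = (torch.rand(batch_size, 1) > dropout_rate).to(x.dtype)  # e.g. [[1.], [0.], [1.]]
- #   y = x * mask  # broadcasts over (seq_len, batch_size, embed_dim)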
-
- def forward(
- self,
- src: Tensor,
- pos_emb: Tensor,
- time_emb: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- src_key_padding_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """
- Pass the input through the encoder layer.
- Args:
- src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
- pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
- time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim)
- or (seq_len, batch_size, embedding_dim) .
- attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
- interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
- True means masked position. May be None.
- src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
- masked position. May be None.
-
- Returns:
- A tensor which has the same shape as src
- """
- src_orig = src
-
- # dropout rate for non-feedforward submodules
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- attention_skip_rate = 0.0
- else:
- attention_skip_rate = (
- float(self.attention_skip_rate) if self.training else 0.0
- )
-
- # attn_weights: (num_heads, batch_size, seq_len, seq_len)
- attn_weights = self.self_attn_weights(
- src,
- pos_emb=pos_emb,
- attn_mask=attn_mask,
- key_padding_mask=src_key_padding_mask,
- )
- if time_emb is not None:
-
- src = src + time_emb
-
- src = src + self.feed_forward1(src)
-
- self_attn_dropout_mask = self.get_sequence_dropout_mask(
- src, attention_skip_rate
- )
-
- selected_attn_weights = attn_weights[0:1]
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- pass
- elif self.training and random.random() < float(self.const_attention_rate):
- # Make attention weights constant. The intention is to
- # encourage these modules to do something similar to an
- # averaging-over-time operation.
- # only need the mask, can just use the 1st one and expand later
- selected_attn_weights = selected_attn_weights[0:1]
- selected_attn_weights = (selected_attn_weights > 0.0).to(
- selected_attn_weights.dtype
- )
- selected_attn_weights = selected_attn_weights * (
- 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
- )
-
- na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
-
- src = src + (
- na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
- )
-
- self_attn = self.self_attn1(src, attn_weights)
-
- src = src + (
- self_attn
- if self_attn_dropout_mask is None
- else self_attn * self_attn_dropout_mask
- )
-
- if self.use_conv:
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- conv_skip_rate = 0.0
- else:
- conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
-
- if time_emb is not None:
- src = src + time_emb
-
- src = src + self.sequence_dropout(
- self.conv_module1(
- src,
- src_key_padding_mask=src_key_padding_mask,
- ),
- conv_skip_rate,
- )
-
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- ff2_skip_rate = 0.0
- else:
- ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
- src = src + self.sequence_dropout(
- self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
- )
-
- # bypass in the middle of the layer.
- src = self.bypass_mid(src_orig, src)
-
- self_attn = self.self_attn2(src, attn_weights)
-
- src = src + (
- self_attn
- if self_attn_dropout_mask is None
- else self_attn * self_attn_dropout_mask
- )
-
- if self.use_conv:
-
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- conv_skip_rate = 0.0
- else:
- conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
-
- if time_emb is not None:
- src = src + time_emb
-
- src = src + self.sequence_dropout(
- self.conv_module2(
- src,
- src_key_padding_mask=src_key_padding_mask,
- ),
- conv_skip_rate,
- )
-
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- ff3_skip_rate = 0.0
- else:
- ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
- src = src + self.sequence_dropout(
- self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
- )
-
- src = self.balancer1(src)
- src = self.norm(src)
-
- src = self.bypass(src_orig, src)
-
- src = self.balancer2(src)
- src = self.whiten(src)
-
- return src
-
-
-class Zipformer2Encoder(nn.Module):
- r"""Zipformer2Encoder is a stack of N encoder layers
-
- Args:
- encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
- num_layers: the number of sub-encoder-layers in the encoder (required).
- embed_dim: the embedding dimension of the encoder (required).
- time_embed_dim: dimension of the incoming time embedding, or -1 if no time embedding is used (required).
- pos_dim: the dimension for the relative positional encoding (required).
- warmup_begin, warmup_end: batch indices over which the per-layer layer-drop (bypass skip) rates are scheduled.
-
- Examples::
- >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, pos_dim=48, num_heads=8, query_head_dim=32, pos_head_dim=4, value_head_dim=12, feedforward_dim=1536)
- >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6, embed_dim=512, time_embed_dim=-1, pos_dim=48, warmup_begin=0.0, warmup_end=100.0)
- >>> src = torch.rand(10, 32, 512)
- >>> out = zipformer_encoder(src)
- """
-
- def __init__(
- self,
- encoder_layer: nn.Module,
- num_layers: int,
- embed_dim: int,
- time_embed_dim: int,
- pos_dim: int,
- warmup_begin: float,
- warmup_end: float,
- initial_layerdrop_rate: float = 0.5,
- final_layerdrop_rate: float = 0.05,
- ) -> None:
- super().__init__()
- self.encoder_pos = CompactRelPositionalEncoding(
- pos_dim, dropout_rate=0.15, length_factor=1.0
- )
- if time_embed_dim != -1:
- self.time_emb = nn.Sequential(
- SwooshR(),
- nn.Linear(time_embed_dim, embed_dim),
- )
- else:
- self.time_emb = None
-
- self.layers = nn.ModuleList(
- [copy.deepcopy(encoder_layer) for i in range(num_layers)]
- )
- self.num_layers = num_layers
-
- assert 0 <= warmup_begin <= warmup_end
-
- delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
- cur_begin = warmup_begin # interpreted as a training batch index
- for i in range(num_layers):
- cur_end = cur_begin + delta
- self.layers[i].bypass.skip_rate = ScheduledFloat(
- (cur_begin, initial_layerdrop_rate),
- (cur_end, final_layerdrop_rate),
- default=0.0,
- )
- cur_begin = cur_end
-
- def forward(
- self,
- src: Tensor,
- time_emb: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- src_key_padding_mask: Optional[Tensor] = None,
- ) -> Tensor:
- r"""Pass the input through the encoder layers in turn.
-
- Args:
- src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
- time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim)
- or (seq_len, batch_size, embedding_dim) .
- attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
- interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
- True means masked position. May be None.
- src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
- masked position. May be None.
-
- Returns: a Tensor with the same shape as src.
- """
- pos_emb = self.encoder_pos(src)
- if self.time_emb is not None:
- assert time_emb is not None
- time_emb = self.time_emb(time_emb)
- else:
- assert time_emb is None
-
- output = src
-
- for i, mod in enumerate(self.layers):
- output = mod(
- output,
- pos_emb,
- time_emb=time_emb,
- attn_mask=attn_mask,
- src_key_padding_mask=src_key_padding_mask,
- )
-
- return output
-
-
-class BypassModule(nn.Module):
- """
- An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
- layer-skipping. The bypass is limited during early stages of training to be close to
- "straight-through", i.e. to not do the bypass operation much initially, in order to
- force all the modules to learn something.
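-
- A minimal usage sketch (illustrative shapes only)::
- >>> bypass = BypassModule(embed_dim=512)
- >>> src_orig = torch.rand(10, 32, 512)
- >>> src = torch.rand(10, 32, 512)
- >>> out = bypass(src_orig, src) # src_orig + (src - src_orig) * bypass_scale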
- """
-
- def __init__(
- self,
- embed_dim: int,
- skip_rate: FloatLike = 0.0,
- straight_through_rate: FloatLike = 0.0,
- scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
- scale_max: FloatLike = 1.0,
- ):
- super().__init__()
- self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
- self.skip_rate = copy.deepcopy(skip_rate)
- self.straight_through_rate = copy.deepcopy(straight_through_rate)
- self.scale_min = copy.deepcopy(scale_min)
- self.scale_max = copy.deepcopy(scale_max)
-
- def _get_bypass_scale(self, batch_size: int):
- # returns bypass-scale of shape (num_channels,),
- # or (batch_size, num_channels,). This is actually the
- # scale on the non-residual term, so 0 corresponds to bypassing
- # this module.
- if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
- return self.bypass_scale
- else:
- ans = limit_param_value(
- self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
- )
- skip_rate = float(self.skip_rate)
- if skip_rate != 0.0:
- mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
- ans = ans * mask
- # now ans is of shape (batch_size, num_channels), and is zero for sequences
- # on which we have randomly chosen to do layer-skipping.
- straight_through_rate = float(self.straight_through_rate)
- if straight_through_rate != 0.0:
- mask = (
- torch.rand((batch_size, 1), device=ans.device)
- < straight_through_rate
- )
- ans = torch.maximum(ans, mask.to(ans.dtype))
- return ans
-
- def forward(self, src_orig: Tensor, src: Tensor):
- """
- Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
- Returns: something with the same shape as src and src_orig
- """
- bypass_scale = self._get_bypass_scale(src.shape[1])
- return src_orig + (src - src_orig) * bypass_scale
-
-
-class DownsampledZipformer2Encoder(nn.Module):
- r"""
- DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate
- after downsampling, then upsampled again at the output and combined with the
- original input, so that the output has the same shape as the input.
- """
-
- def __init__(self, encoder: nn.Module, dim: int, downsample: int):
- super(DownsampledZipformer2Encoder, self).__init__()
- self.downsample_factor = downsample
- self.downsample = SimpleDownsample(downsample)
- self.num_layers = encoder.num_layers
- self.encoder = encoder
- self.upsample = SimpleUpsample(downsample)
- self.out_combiner = BypassModule(dim, straight_through_rate=0)
-
- def forward(
- self,
- src: Tensor,
- time_emb: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- src_key_padding_mask: Optional[Tensor] = None,
- ) -> Tensor:
- r"""Downsample, go through encoder, upsample.
-
- Args:
- src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
- time_emb: the embedding representing the current timestep: shape (batch_size, embedding_dim)
- or (seq_len, batch_size, embedding_dim) .
- attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
- interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
- True means masked position. May be None.
- src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
- masked position. May be None.
-
- Returns: a Tensor with the same shape as src.
- """
- src_orig = src
- src = self.downsample(src)
- ds = self.downsample_factor
- if time_emb is not None and time_emb.dim() == 3:
- time_emb = time_emb[::ds]
- if attn_mask is not None:
- attn_mask = attn_mask[::ds, ::ds]
- if src_key_padding_mask is not None:
- src_key_padding_mask = src_key_padding_mask[..., ::ds]
-
- src = self.encoder(
- src,
- time_emb=time_emb,
- attn_mask=attn_mask,
- src_key_padding_mask=src_key_padding_mask,
- )
- src = self.upsample(src)
- # remove any extra frames that are not a multiple of downsample_factor
- src = src[: src_orig.shape[0]]
-
- return self.out_combiner(src_orig, src)
-
-
-class SimpleDownsample(torch.nn.Module):
- """
- Does downsampling by a learned (softmax-weighted) sum over each group of `downsample` consecutive frames.
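-
- A minimal usage sketch (illustrative shapes only)::
- >>> m = SimpleDownsample(2)
- >>> src = torch.rand(10, 32, 512)
- >>> out = m(src) # out.shape == (5, 32, 512)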
- """
-
- def __init__(self, downsample: int):
- super(SimpleDownsample, self).__init__()
-
- self.bias = nn.Parameter(torch.zeros(downsample))
-
- self.name = None # will be set from training code
-
- self.downsample = downsample
-
- def forward(self, src: Tensor) -> Tensor:
- """
- x: (seq_len, batch_size, in_channels)
- Returns a tensor of shape
- ( (seq_len+downsample-1)//downsample, batch_size, channels)
- """
- (seq_len, batch_size, in_channels) = src.shape
- ds = self.downsample
- d_seq_len = (seq_len + ds - 1) // ds
-
- # Pad to an exact multiple of self.downsample
- # right-pad src, repeating the last element.
- pad = d_seq_len * ds - seq_len
- src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
- src = torch.cat((src, src_extra), dim=0)
- assert src.shape[0] == d_seq_len * ds
-
- src = src.reshape(d_seq_len, ds, batch_size, in_channels)
-
- weights = self.bias.softmax(dim=0)
- # weights: (downsample, 1, 1)
- weights = weights.unsqueeze(-1).unsqueeze(-1)
-
- # ans: weighted sum over each group of `ds` consecutive frames
- ans = (src * weights).sum(dim=1)
-
- return ans
-
-
-class SimpleUpsample(torch.nn.Module):
- """
- A very simple form of upsampling that just repeats the input.
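-
- A minimal usage sketch (illustrative shapes only)::
- >>> m = SimpleUpsample(2)
- >>> src = torch.rand(10, 32, 512)
- >>> out = m(src) # out.shape == (20, 32, 512)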
- """
-
- def __init__(self, upsample: int):
- super(SimpleUpsample, self).__init__()
- self.upsample = upsample
-
- def forward(self, src: Tensor) -> Tensor:
- """
- x: (seq_len, batch_size, num_channels)
- Returns a tensor of shape
- ( (seq_len*upsample), batch_size, num_channels)
- """
- upsample = self.upsample
- (seq_len, batch_size, num_channels) = src.shape
- src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
- src = src.reshape(seq_len * upsample, batch_size, num_channels)
- return src
-
-
-class CompactRelPositionalEncoding(torch.nn.Module):
- """
- Relative positional encoding module. This version is "compact" meaning it is able to encode
- the important information about the relative position in a relatively small number of dimensions.
- The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
- make very little difference to the embedding. Such differences were potentially important
- when encoding absolute position, but not important when encoding relative position because there
- is now no need to compare two large offsets with each other.
-
- Our embedding works by projecting the interval [-infinity,infinity] to a finite interval
- using the atan() function, before doing the Fourier transform of that fixed interval. The
- atan() function alone would compress the "long tails" too much,
- making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
- function to compress large offsets to a smaller range before applying atan().
- Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
- as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim)
-
-
- Args:
- embed_dim: Embedding dimension.
- dropout_rate: Dropout rate.
- max_len: Maximum input length: just a heuristic for initialization.
- length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
- less weight to small differences of offset near the origin.
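-
- A minimal usage sketch (illustrative shapes only)::
- >>> pos = CompactRelPositionalEncoding(embed_dim=64, dropout_rate=0.0)
- >>> x = torch.rand(10, 32, 512) # only x.size(0), the time axis, matters here
- >>> pos_emb = pos(x) # pos_emb.shape == (1, 2 * 10 - 1, 64)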
- """
-
- def __init__(
- self,
- embed_dim: int,
- dropout_rate: FloatLike,
- max_len: int = 1000,
- length_factor: float = 1.0,
- ) -> None:
- """Construct a CompactRelPositionalEncoding object."""
- super(CompactRelPositionalEncoding, self).__init__()
- self.embed_dim = embed_dim
- assert embed_dim % 2 == 0, embed_dim
- self.dropout = Dropout2(dropout_rate)
- self.pe = None
- assert length_factor >= 1.0, length_factor
- self.length_factor = length_factor
- self.extend_pe(torch.tensor(0.0).expand(max_len))
-
- def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
- """Reset the positional encodings."""
- T = x.size(0) + left_context_len
-
- if self.pe is not None:
- # self.pe contains both positive and negative parts
- # the length of self.pe is 2 * input_len - 1
- if self.pe.size(0) >= T * 2 - 1:
- self.pe = self.pe.to(dtype=x.dtype, device=x.device)
- return
-
- # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
- x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
-
- freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
-
- # `compression_length` is arbitrary/heuristic: if it is larger, we have more resolution
- # for small time offsets but less resolution for large time offsets.
- compression_length = self.embed_dim**0.5
- # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity;
- # but it does so more slowly than x for large absolute values of x.
- # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
- # is important.
- x_compressed = (
- compression_length
- * x.sign()
- * ((x.abs() + compression_length).log() - math.log(compression_length))
- )
-
- # if self.length_factor == 1.0, then length_scale is chosen so that the
- # FFT can exactly separate points close to the origin (T == 0). So this
- # part of the formulation is not really heuristic.
- # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
- length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
-
- # note for machine implementations: if atan is not available, we can use:
- # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
- # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
- x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
-
- cosines = (x_atan * freqs).cos()
- sines = (x_atan * freqs).sin()
-
- pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
- pe[:, 0::2] = cosines
- pe[:, 1::2] = sines
- pe[:, -1] = 1.0 # for bias.
-
- self.pe = pe.to(dtype=x.dtype)
-
- def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
- """Create positional encoding.
-
- Args:
- x (Tensor): Input tensor (time, batch, `*`).
- left_context_len: (int): Length of cached left context.
-
- Returns:
- positional embedding, of shape (1, left_context_len + 2*time-1, embed_dim).
- """
- self.extend_pe(x, left_context_len)
- x_size_left = x.size(0) + left_context_len
- # length of positive side: x.size(0) + left_context_len
- # length of negative side: x.size(0)
- pos_emb = self.pe[
- self.pe.size(0) // 2
- - x_size_left
- + 1 : self.pe.size(0) // 2 # noqa E203
- + x.size(0),
- :,
- ]
- pos_emb = pos_emb.unsqueeze(0)
- return self.dropout(pos_emb)
-
-
-class RelPositionMultiheadAttentionWeights(nn.Module):
- r"""Module that computes multi-head attention weights with relative position encoding.
- Various other modules consume the resulting attention weights: see, for example, the
- SimpleAttention module which allows you to compute conventional attention.
-
- This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
- the differences still need to be written up.
-
-
- Args:
- embed_dim: number of channels at the input to this module, e.g. 256
- pos_dim: dimension of the positional encoding vectors, e.g. 128.
- num_heads: number of heads to compute weights for, e.g. 8
- query_head_dim: dimension of the query (and key), per head. e.g. 24.
- pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
- dropout: dropout probability for attn_output_weights. Default: 0.0.
- pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
- any given call to forward(), in training time.
- """
-
- def __init__(
- self,
- embed_dim: int,
- pos_dim: int,
- num_heads: int,
- query_head_dim: int,
- pos_head_dim: int,
- dropout: float = 0.0,
- pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
- ) -> None:
- super().__init__()
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.query_head_dim = query_head_dim
- self.pos_head_dim = pos_head_dim
- self.dropout = dropout
- self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
- self.name = None # will be overwritten in training code; for diagnostics.
-
- key_head_dim = query_head_dim
- in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
-
- # the initial_scale is supposed to take over the "scaling" factor of
- # head_dim ** -0.5 that has been used in previous forms of attention,
- # dividing it between the query and key. Note: this module is intended
- # to be used with the ScaledAdam optimizer; with most other optimizers,
- # it would be necessary to apply the scaling factor in the forward function.
- self.in_proj = ScaledLinear(
- embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
- )
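- # For example (illustrative numbers only): with query_head_dim=32, the
- # initial_scale is 32 ** -0.25 ~= 0.42, so q and k each carry one factor of
- # 32 ** -0.25 and their dot product carries the usual 1 / sqrt(32) scaling.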
-
- self.whiten_keys = Whiten(
- num_groups=num_heads,
- whitening_limit=_whitening_schedule(3.0),
- prob=(0.025, 0.25),
- grad_scale=0.025,
- )
-
- # add a balancer for the keys that runs with very small probability, and
- # tries to enforce that all dimensions have mean around zero. The
- # weights produced by this module are invariant to adding a constant to
- # the keys, so the derivative of the bias is mathematically zero; but
- # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
- # bias because the small numerical roundoff tends to have a non-random
- # sign. This module is intended to prevent that. Use a very small
- # probability; that should be sufficient to fix the problem.
- self.balance_keys = Balancer(
- key_head_dim * num_heads,
- channel_dim=-1,
- min_positive=0.4,
- max_positive=0.6,
- min_abs=0.0,
- max_abs=100.0,
- prob=0.025,
- )
-
- # linear transformation for positional encoding.
- self.linear_pos = ScaledLinear(
- pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
- )
-
- # the following are for diagnostics only, see --print-diagnostics option
- self.copy_pos_query = Identity()
- self.copy_query = Identity()
-
- def forward(
- self,
- x: Tensor,
- pos_emb: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- ) -> Tensor:
- r"""
- Args:
- x: input of shape (seq_len, batch_size, embed_dim)
- pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
- key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
- are True in this mask will be ignored as sources in the attention weighting.
- attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
- interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
- saying which positions are allowed to attend to which other positions.
- Returns:
- a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
- interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
- """
- x = self.in_proj(x)
- query_head_dim = self.query_head_dim
- pos_head_dim = self.pos_head_dim
- num_heads = self.num_heads
-
- seq_len, batch_size, _ = x.shape
-
- query_dim = query_head_dim * num_heads
-
- # self-attention
- q = x[..., 0:query_dim]
- k = x[..., query_dim : 2 * query_dim]
- # p is the position-encoding query
- p = x[..., 2 * query_dim :]
- assert p.shape[-1] == num_heads * pos_head_dim, (
- p.shape[-1],
- num_heads,
- pos_head_dim,
- )
-
- q = self.copy_query(q) # for diagnostics only, does nothing.
- k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
- p = self.copy_pos_query(p) # for diagnostics only, does nothing.
-
- q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
- p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
- k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
-
- # time1 refers to target, time2 refers to source.
- q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
- p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
- k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
-
- attn_scores = torch.matmul(q, k)
-
- use_pos_scores = False
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- # random.random() cannot be used when scripting or tracing, so always use the pos scores here.
- use_pos_scores = True
- elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
- use_pos_scores = True
-
- if use_pos_scores:
- pos_emb = self.linear_pos(pos_emb)
- seq_len2 = 2 * seq_len - 1
- pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
- 2, 0, 3, 1
- )
- # pos_emb shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2)
-
- # (head, batch, time1, pos_head_dim) x (head, 1, pos_head_dim, seq_len2) -> (head, batch, time1, seq_len2)
- # [where seq_len2 represents relative position.]
- pos_scores = torch.matmul(p, pos_emb)
- # the following .as_strided() expression converts the last axis of pos_scores from relative
- # to absolute position. I don't know whether I might have got the time-offsets backwards or
- # not, but let this code define which way round it is supposed to be.
- if torch.jit.is_tracing():
- (num_heads, batch_size, time1, n) = pos_scores.shape
- rows = torch.arange(start=time1 - 1, end=-1, step=-1)
- cols = torch.arange(seq_len)
- rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
- indexes = rows + cols
- pos_scores = pos_scores.reshape(-1, n)
- pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
- pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
- else:
- pos_scores = pos_scores.as_strided(
- (num_heads, batch_size, seq_len, seq_len),
- (
- pos_scores.stride(0),
- pos_scores.stride(1),
- pos_scores.stride(2) - pos_scores.stride(3),
- pos_scores.stride(3),
- ),
- storage_offset=pos_scores.stride(3) * (seq_len - 1),
- )
-
- attn_scores = attn_scores + pos_scores
-
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- pass
- elif self.training and random.random() < 0.1:
- # This is a harder way of limiting the attention scores to not be
- # too large. It incurs a penalty if any of them has an absolute
- # value greater than 50.0. this should be outside the normal range
- # of the attention scores. We use this mechanism instead of, say,
- # something added to the loss function involving the entropy,
- # because once the entropy gets very small gradients through the
- # softmax can become very small, and we'd get zero derivatives. The
- # choices of 1.0e-04 as the scale on the penalty makes this
- # mechanism vulnerable to the absolute scale of the loss function,
- # but we view this as a failsafe to avoid "implausible" parameter
- # values rather than a regularization method that should be active
- # under normal circumstances.
- attn_scores = penalize_abs_values_gt(
- attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
- )
-
- assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
-
- if attn_mask is not None:
- assert attn_mask.dtype == torch.bool
- # use -1000 to avoid nan's where attn_mask and key_padding_mask make
- # all scores zero. It's important that the magnitude be large enough that
- # exp(-1000) is exactly zero; this matters for const_attention_rate, which
- # compares the final weights with zero.
- attn_scores = attn_scores.masked_fill(attn_mask, -1000)
-
- if key_padding_mask is not None:
- assert key_padding_mask.shape == (
- batch_size,
- seq_len,
- ), key_padding_mask.shape
- attn_scores = attn_scores.masked_fill(
- key_padding_mask.unsqueeze(1),
- -1000,
- )
-
- # We use our own version of softmax, defined in scaling.py, which should
- # save a little of the memory used in backprop: if we are in
- # automatic mixed precision mode (amp / autocast), it only stores the
- # half-precision output for backprop purposes.
- attn_weights = softmax(attn_scores, dim=-1)
-
- if torch.jit.is_scripting() or torch.jit.is_tracing():
- pass
- elif random.random() < 0.001 and not self.training:
- self._print_attn_entropy(attn_weights)
-
- attn_weights = nn.functional.dropout(
- attn_weights, p=self.dropout, training=self.training
- )
-
- return attn_weights
-
- def _print_attn_entropy(self, attn_weights: Tensor):
- # attn_weights: (num_heads, batch_size, seq_len, seq_len)
- (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
-
- with torch.no_grad():
- with torch.amp.autocast("cuda", enabled=False):
- attn_weights = attn_weights.to(torch.float32)
- attn_weights_entropy = (
- -((attn_weights + 1.0e-20).log() * attn_weights)
- .sum(dim=-1)
- .mean(dim=(1, 2))
- )
- logging.debug(
- f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
- )
-
-
-class SelfAttention(nn.Module):
- """
- The simplest possible attention module. This one works with already-computed attention
- weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
-
- Args:
- embed_dim: the input and output embedding dimension
- num_heads: the number of attention heads
- value_head_dim: the value dimension per head
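-
- A minimal usage sketch (illustrative shapes only)::
- >>> attn = SelfAttention(embed_dim=512, num_heads=8, value_head_dim=12)
- >>> x = torch.rand(10, 32, 512)
- >>> weights = torch.softmax(torch.rand(8, 32, 10, 10), dim=-1)
- >>> out = attn(x, weights) # out.shape == (10, 32, 512)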
- """
-
- def __init__(
- self,
- embed_dim: int,
- num_heads: int,
- value_head_dim: int,
- ) -> None:
- super().__init__()
- self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
-
- self.out_proj = ScaledLinear(
- num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
- )
-
- self.whiten = Whiten(
- num_groups=1,
- whitening_limit=_whitening_schedule(7.5, ratio=3.0),
- prob=(0.025, 0.25),
- grad_scale=0.01,
- )
-
- def forward(
- self,
- x: Tensor,
- attn_weights: Tensor,
- ) -> Tensor:
- """
- Args:
- x: input tensor, of shape (seq_len, batch_size, embed_dim)
- attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
- with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
- attn_weights.sum(dim=-1) == 1.
- Returns:
- a tensor with the same shape as x.
- """
- (seq_len, batch_size, embed_dim) = x.shape
- num_heads = attn_weights.shape[0]
- assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
-
- x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
- x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
- # now x: (num_heads, batch_size, seq_len, value_head_dim)
- value_head_dim = x.shape[-1]
-
- # todo: see whether there is benefit in overriding matmul
- x = torch.matmul(attn_weights, x)
- # v: (num_heads, batch_size, seq_len, value_head_dim)
-
- x = (
- x.permute(2, 1, 0, 3)
- .contiguous()
- .view(seq_len, batch_size, num_heads * value_head_dim)
- )
-
- # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
- x = self.out_proj(x)
- x = self.whiten(x)
-
- return x
-
-
-class FeedforwardModule(nn.Module):
- """Feedforward module in TTSZipformer model."""
-
- def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
- super(FeedforwardModule, self).__init__()
- self.in_proj = nn.Linear(embed_dim, feedforward_dim)
-
- self.hidden_balancer = Balancer(
- feedforward_dim,
- channel_dim=-1,
- min_positive=0.3,
- max_positive=1.0,
- min_abs=0.75,
- max_abs=5.0,
- )
-
- # shared_dim=0 means we share the dropout mask along the time axis
- self.out_proj = ActivationDropoutAndLinear(
- feedforward_dim,
- embed_dim,
- activation="SwooshL",
- dropout_p=dropout,
- dropout_shared_dim=0,
- bias=True,
- initial_scale=0.1,
- )
-
- self.out_whiten = Whiten(
- num_groups=1,
- whitening_limit=_whitening_schedule(7.5),
- prob=(0.025, 0.25),
- grad_scale=0.01,
- )
-
- def forward(self, x: Tensor):
- x = self.in_proj(x)
- x = self.hidden_balancer(x)
- # out_proj contains SwooshL activation, then dropout, then linear.
- x = self.out_proj(x)
- x = self.out_whiten(x)
- return x
-
-
-class NonlinAttention(nn.Module):
- """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
- from the attention module) in place of actual convolution. We also took out the second nonlinearity, the
- one after the attention mechanism.
-
- Args:
- channels (int): The number of input/output channels.
- hidden_channels (int): The number of hidden channels; in_proj produces 3 * hidden_channels, split into three equal parts.
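-
- A minimal usage sketch (illustrative shapes only)::
- >>> m = NonlinAttention(channels=512, hidden_channels=384)
- >>> x = torch.rand(10, 32, 512)
- >>> weights = torch.softmax(torch.rand(8, 32, 10, 10), dim=-1)
- >>> out = m(x, weights) # out.shape == (10, 32, 512)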
- """
-
- def __init__(
- self,
- channels: int,
- hidden_channels: int,
- ) -> None:
- super().__init__()
-
- self.hidden_channels = hidden_channels
-
- self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
-
- # balancer that goes before the tanh. Have a fairly large min_abs value,
- # because we noticed that well-trained instances of this module have large
- # abs-values before the nonlinearity, while poorly-trained instances have
- # smaller abs-values there.
- self.balancer = Balancer(
- hidden_channels,
- channel_dim=-1,
- min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
- max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
- min_abs=0.5,
- max_abs=5.0,
- )
- self.tanh = nn.Tanh()
-
- self.identity1 = Identity() # for diagnostics.
- self.identity2 = Identity() # for diagnostics.
- self.identity3 = Identity() # for diagnostics.
-
- self.out_proj = ScaledLinear(
- hidden_channels, channels, bias=True, initial_scale=0.05
- )
-
- self.whiten1 = Whiten(
- num_groups=1,
- whitening_limit=_whitening_schedule(5.0),
- prob=(0.025, 0.25),
- grad_scale=0.01,
- )
-
- self.whiten2 = Whiten(
- num_groups=1,
- whitening_limit=_whitening_schedule(5.0, ratio=3.0),
- prob=(0.025, 0.25),
- grad_scale=0.01,
- )
-
- def forward(
- self,
- x: Tensor,
- attn_weights: Tensor,
- ) -> Tensor:
- """.
- Args:
- x: a Tensor of shape (seq_len, batch_size, num_channels)
- attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
- Returns:
- a Tensor with the same shape as x
- """
- x = self.in_proj(x)
-
- (seq_len, batch_size, _) = x.shape
- hidden_channels = self.hidden_channels
-
- s, x, y = x.chunk(3, dim=2)
-
- # s will go through tanh.
-
- s = self.balancer(s)
- s = self.tanh(s)
-
- s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
- x = self.whiten1(x)
- x = x * s
- x = self.identity1(x) # diagnostics only, it's the identity.
-
- (seq_len, batch_size, embed_dim) = x.shape
- num_heads = attn_weights.shape[0]
- assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
-
- x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
- # now x: (num_heads, batch_size, seq_len, head_dim)
- x = torch.matmul(attn_weights, x)
- # now x: (num_heads, batch_size, seq_len, head_dim)
- x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
-
- y = self.identity2(y)
- x = x * y
- x = self.identity3(x)
-
- x = self.out_proj(x)
- x = self.whiten2(x)
- return x
-
-
-class ConvolutionModule(nn.Module):
- """ConvolutionModule in Zipformer2 model.
- Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
-
- Args:
- channels (int): The number of channels of conv layers.
- kernel_size (int): Kernel size of conv layers.
-
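- A minimal usage sketch (illustrative shapes only)::
- >>> conv = ConvolutionModule(channels=512, kernel_size=31)
- >>> x = torch.rand(10, 32, 512)
- >>> out = conv(x) # out.shape == (10, 32, 512)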
- """
-
- def __init__(
- self,
- channels: int,
- kernel_size: int,
- ) -> None:
- """Construct a ConvolutionModule object."""
- super(ConvolutionModule, self).__init__()
- # kernel_size should be an odd number for 'SAME' padding
- assert (kernel_size - 1) % 2 == 0
-
- bottleneck_dim = channels
-
- self.in_proj = nn.Linear(
- channels,
- 2 * bottleneck_dim,
- )
- # the gradients on in_proj are a little noisy, likely to do with the
- # sigmoid in glu.
-
- # after in_proj we put x through a gated linear unit (nn.functional.glu).
- # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
- # but sometimes, for some reason, for layer 0 the rms ends up being very large,
- # between 50 and 100 for different channels. This will cause very peaky and
- # sparse derivatives for the sigmoid gating function, which will tend to make
- # the loss function not learn effectively. (for most layers the average absolute values
- # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
- # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
- # layers, which likely breaks down as 0.5 for the "linear" half and
- # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we
- # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
- # it will be in a better position to start learning something, i.e. to latch onto
- # the correct range.
- self.balancer1 = Balancer(
- bottleneck_dim,
- channel_dim=-1,
- min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
- max_positive=1.0,
- min_abs=1.5,
- max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
- )
-
- self.activation1 = Identity() # for diagnostics
-
- self.sigmoid = nn.Sigmoid()
-
- self.activation2 = Identity() # for diagnostics
-
- assert kernel_size % 2 == 1
-
- self.depthwise_conv = nn.Conv1d(
- in_channels=bottleneck_dim,
- out_channels=bottleneck_dim,
- groups=bottleneck_dim,
- kernel_size=kernel_size,
- padding=kernel_size // 2,
- )
-
- self.balancer2 = Balancer(
- bottleneck_dim,
- channel_dim=1,
- min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
- max_positive=1.0,
- min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
- max_abs=10.0,
- )
-
- self.whiten = Whiten(
- num_groups=1,
- whitening_limit=_whitening_schedule(7.5),
- prob=(0.025, 0.25),
- grad_scale=0.01,
- )
-
- self.out_proj = ActivationDropoutAndLinear(
- bottleneck_dim,
- channels,
- activation="SwooshR",
- dropout_p=0.0,
- initial_scale=0.05,
- )
-
- def forward(
- self,
- x: Tensor,
- src_key_padding_mask: Optional[Tensor] = None,
- ) -> Tensor:
- """Compute convolution module.
-
- Args:
- x: Input tensor (#time, batch, channels).
- src_key_padding_mask: the mask for the src keys per batch (optional):
- (batch, #time), contains True in masked positions.
-
- Returns:
- Tensor: Output tensor (#time, batch, channels).
-
- """
-
- x = self.in_proj(x) # (time, batch, 2*channels)
-
- x, s = x.chunk(2, dim=2)
- s = self.balancer1(s)
- s = self.sigmoid(s)
- x = self.activation1(x) # identity.
- x = x * s
- x = self.activation2(x) # identity
-
- # (time, batch, channels)
-
- # exchange the temporal dimension and the feature dimension
- x = x.permute(1, 2, 0) # (#batch, channels, time).
-
- if src_key_padding_mask is not None:
- x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-
- x = self.depthwise_conv(x)
-
- x = self.balancer2(x)
- x = x.permute(2, 0, 1) # (time, batch, channels)
-
- x = self.whiten(x) # (time, batch, channels)
- x = self.out_proj(x) # (time, batch, channels)
-
- return x
diff --git a/egs/zipvoice/zipvoice/zipvoice_infer.py b/egs/zipvoice/zipvoice/zipvoice_infer.py
deleted file mode 100644
index 16cf8e039..000000000
--- a/egs/zipvoice/zipvoice/zipvoice_infer.py
+++ /dev/null
@@ -1,645 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This script generates speech with our pre-trained ZipVoice or
- ZipVoice-Distill models. Required models will be automatically
- downloaded from HuggingFace.
-
-Usage:
-
-Note: If you have trouble connecting to HuggingFace,
- try switching the endpoint to a mirror site:
-
-export HF_ENDPOINT=https://hf-mirror.com
-
-(1) Inference of a single sentence:
-
-python3 zipvoice/zipvoice_infer.py \
- --model-name "zipvoice_distill" \
- --prompt-wav prompt.wav \
- --prompt-text "I am a prompt." \
- --text "I am a sentence." \
- --res-wav-path result.wav
-
-(2) Inference of a list of sentences:
-python3 zipvoice/zipvoice_infer.py \
- --model-name "zipvoice-distill" \
- --test-list test.tsv \
- --res-dir results
-
-`--model-name` can be `zipvoice` or `zipvoice_distill`,
- which are the models before and after distillation, respectively.
-
-Each line of `test.tsv` is in the format of
- `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
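-
- For example, a single (hypothetical) line could look like:
- `utt_0001\tTranscription of the prompt wav.\tprompts/utt_0001.wav\tText to be synthesized.`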
-"""
-
-import argparse
-import datetime as dt
-import os
-
-import numpy as np
-import safetensors.torch
-import torch
-import torch.nn as nn
-import torchaudio
-from feature import TorchAudioFbank, TorchAudioFbankConfig
-from huggingface_hub import hf_hub_download
-from lhotse.utils import fix_random_seed
-from model import get_distill_model, get_model
-from tokenizer import TokenizerEmilia
-from utils import AttributeDict
-from vocos import Vocos
-
-
-def get_parser():
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter
- )
-
- parser.add_argument(
- "--model-name",
- type=str,
- default="zipvoice_distill",
- choices=["zipvoice", "zipvoice_distill"],
- help="The model used for inference",
- )
-
- parser.add_argument(
- "--test-list",
- type=str,
- default=None,
- help="The list of prompt speech, prompt_transcription, "
- "and text to synthesizein the format of "
- "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
- )
-
- parser.add_argument(
- "--prompt-wav",
- type=str,
- default=None,
- help="The prompt wav to mimic",
- )
-
- parser.add_argument(
- "--prompt-text",
- type=str,
- default=None,
- help="The transcription of the prompt wav",
- )
-
- parser.add_argument(
- "--text",
- type=str,
- default=None,
- help="The text to synthesize",
- )
-
- parser.add_argument(
- "--res-dir",
- type=str,
- default="results",
- help="""
- Path name of the generated wavs dir,
- used when test-list is not None
- """,
- )
-
- parser.add_argument(
- "--res-wav-path",
- type=str,
- default="result.wav",
- help="""
- Path name of the generated wav path,
- used when test-list is None
- """,
- )
-
- parser.add_argument(
- "--guidance-scale",
- type=float,
- default=None,
- help="The scale of classifier-free guidance during inference.",
- )
-
- parser.add_argument(
- "--num-step",
- type=int,
- default=None,
- help="The number of sampling steps.",
- )
-
- parser.add_argument(
- "--feat-scale",
- type=float,
- default=0.1,
- help="The scale factor of fbank feature",
- )
-
- parser.add_argument(
- "--speed",
- type=float,
- default=1.0,
- help="Control speech speed, 1.0 means normal, >1.0 means speed up",
- )
-
- parser.add_argument(
- "--t-shift",
- type=float,
- default=0.5,
- help="Shift t to smaller ones if t_shift < 1.0",
- )
-
- parser.add_argument(
- "--target-rms",
- type=float,
- default=0.1,
- help="Target speech normalization rms value",
- )
-
- parser.add_argument(
- "--seed",
- type=int,
- default=666,
- help="Random seed",
- )
-
- add_model_arguments(parser)
-
- return parser
-
-
-def add_model_arguments(parser: argparse.ArgumentParser):
- parser.add_argument(
- "--fm-decoder-downsampling-factor",
- type=str,
- default="1,2,4,2,1",
- help="Downsampling factor for each stack of encoder layers.",
- )
-
- parser.add_argument(
- "--fm-decoder-num-layers",
- type=str,
- default="2,2,4,4,4",
- help="Number of zipformer encoder layers per stack, comma separated.",
- )
-
- parser.add_argument(
- "--fm-decoder-cnn-module-kernel",
- type=str,
- default="31,15,7,15,31",
- help="Sizes of convolutional kernels in convolution modules "
- "in each encoder stack: a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--fm-decoder-feedforward-dim",
- type=int,
- default=1536,
- help="Feedforward dimension of the zipformer encoder layers, "
- "per stack, comma separated.",
- )
-
- parser.add_argument(
- "--fm-decoder-num-heads",
- type=int,
- default=4,
- help="Number of attention heads in the zipformer encoder layers: "
- "a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--fm-decoder-dim",
- type=int,
- default=512,
- help="Embedding dimension in encoder stacks: a single int "
- "or comma-separated list.",
- )
-
- parser.add_argument(
- "--text-encoder-downsampling-factor",
- type=str,
- default="1",
- help="Downsampling factor for each stack of encoder layers.",
- )
-
- parser.add_argument(
- "--text-encoder-num-layers",
- type=str,
- default="4",
- help="Number of zipformer encoder layers per stack, comma separated.",
- )
-
- parser.add_argument(
- "--text-encoder-feedforward-dim",
- type=int,
- default=512,
- help="Feedforward dimension of the zipformer encoder layers, "
- "per stack, comma separated.",
- )
-
- parser.add_argument(
- "--text-encoder-cnn-module-kernel",
- type=str,
- default="9",
- help="Sizes of convolutional kernels in convolution modules in "
- "each encoder stack: a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--text-encoder-num-heads",
- type=int,
- default=4,
- help="Number of attention heads in the zipformer encoder layers: "
- "a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--text-encoder-dim",
- type=int,
- default=192,
- help="Embedding dimension in encoder stacks: "
- "a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--query-head-dim",
- type=int,
- default=32,
- help="Query/key dimension per head in encoder stacks: "
- "a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--value-head-dim",
- type=int,
- default=12,
- help="Value dimension per head in encoder stacks: "
- "a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--pos-head-dim",
- type=int,
- default=4,
- help="Positional-encoding dimension per head in encoder stacks: "
- "a single int or comma-separated list.",
- )
-
- parser.add_argument(
- "--pos-dim",
- type=int,
- default=48,
- help="Positional-encoding embedding dimension",
- )
-
- parser.add_argument(
- "--time-embed-dim",
- type=int,
- default=192,
- help="Embedding dimension of timestamps embedding.",
- )
-
- parser.add_argument(
- "--text-embed-dim",
- type=int,
- default=192,
- help="Embedding dimension of text embedding.",
- )
-
-
-def get_params() -> AttributeDict:
- params = AttributeDict(
- {
- "sampling_rate": 24000,
- "frame_shift_ms": 256 / 24000 * 1000,
- "feat_dim": 100,
- }
- )
-
- return params
-
-
-def get_vocoder():
- vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
- return vocoder
-
-
-def generate_sentence(
- save_path: str,
- prompt_text: str,
- prompt_wav: str,
- text: str,
- model: nn.Module,
- vocoder: nn.Module,
- tokenizer: TokenizerEmilia,
- feature_extractor: TorchAudioFbank,
- device: torch.device,
- num_step: int = 16,
- guidance_scale: float = 1.0,
- speed: float = 1.0,
- t_shift: float = 0.5,
- target_rms: float = 0.1,
- feat_scale: float = 0.1,
- sampling_rate: int = 24000,
-):
- """
- Generate waveform of a text based on a given prompt
- waveform and its transcription.
-
- Args:
- save_path (str): Path to save the generated wav.
- prompt_text (str): Transcription of the prompt wav.
- prompt_wav (str): Path to the prompt wav file.
- text (str): Text to be synthesized into a waveform.
- model (nn.Module): The model used for generation.
- vocoder (nn.Module): The vocoder used to convert features to waveforms.
- tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
- feature_extractor (TorchAudioFbank): The feature extractor used to
- extract acoustic features.
- device (torch.device): The device on which computations are performed.
- num_step (int, optional): Number of steps for decoding. Defaults to 16.
- guidance_scale (float, optional): Scale for classifier-free guidance.
- Defaults to 1.0.
- speed (float, optional): Speed control. Defaults to 1.0.
- t_shift (float, optional): Time shift. Defaults to 0.5.
- target_rms (float, optional): Target RMS for waveform normalization.
- Defaults to 0.1.
- feat_scale (float, optional): Scale for features.
- Defaults to 0.1.
- sampling_rate (int, optional): Sampling rate for the waveform.
- Defaults to 24000.
- Returns:
- metrics (dict): Dictionary containing time and real-time
- factor metrics for processing.
- """
- # Convert text to tokens
- tokens = tokenizer.texts_to_token_ids([text])
- prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
-
- # Load and preprocess prompt wav
- prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
- prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
- if prompt_rms < target_rms:
- prompt_wav = prompt_wav * target_rms / prompt_rms
-
- if prompt_sampling_rate != sampling_rate:
- resampler = torchaudio.transforms.Resample(
- orig_freq=prompt_sampling_rate, new_freq=sampling_rate
- )
- prompt_wav = resampler(prompt_wav)
-
- # Extract features from prompt wav
- prompt_features = feature_extractor.extract(
- prompt_wav, sampling_rate=sampling_rate
- ).to(device)
- prompt_features = prompt_features.unsqueeze(0) * feat_scale
- prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
-
- # Start timing
- start_t = dt.datetime.now()
-
- # Generate features
- (
- pred_features,
- pred_features_lens,
- pred_prompt_features,
- pred_prompt_features_lens,
- ) = model.sample(
- tokens=tokens,
- prompt_tokens=prompt_tokens,
- prompt_features=prompt_features,
- prompt_features_lens=prompt_features_lens,
- speed=speed,
- t_shift=t_shift,
- duration="predict",
- num_step=num_step,
- guidance_scale=guidance_scale,
- )
-
- # Postprocess predicted features
- pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
-
- # Start vocoder processing
- start_vocoder_t = dt.datetime.now()
- wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
-
- # Calculate processing times and real-time factors
- t = (dt.datetime.now() - start_t).total_seconds()
- t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
- t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
- wav_seconds = wav.shape[-1] / sampling_rate
- rtf = t / wav_seconds
- rtf_no_vocoder = t_no_vocoder / wav_seconds
- rtf_vocoder = t_vocoder / wav_seconds
- metrics = {
- "t": t,
- "t_no_vocoder": t_no_vocoder,
- "t_vocoder": t_vocoder,
- "wav_seconds": wav_seconds,
- "rtf": rtf,
- "rtf_no_vocoder": rtf_no_vocoder,
- "rtf_vocoder": rtf_vocoder,
- }
-
- # Adjust wav volume if necessary
- if prompt_rms < target_rms:
- wav = wav * prompt_rms / target_rms
- torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
-
- return metrics
-
-
-def generate(
- res_dir: str,
- test_list: str,
- model: nn.Module,
- vocoder: nn.Module,
- tokenizer: TokenizerEmilia,
- feature_extractor: TorchAudioFbank,
- device: torch.device,
- num_step: int = 16,
- guidance_scale: float = 1.0,
- speed: float = 1.0,
- t_shift: float = 0.5,
- target_rms: float = 0.1,
- feat_scale: float = 0.1,
- sampling_rate: int = 24000,
-):
- total_t = []
- total_t_no_vocoder = []
- total_t_vocoder = []
- total_wav_seconds = []
-
- with open(test_list, "r") as fr:
- lines = fr.readlines()
-
- for i, line in enumerate(lines):
- wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
- save_path = f"{res_dir}/{wav_name}.wav"
- metrics = generate_sentence(
- save_path=save_path,
- prompt_text=prompt_text,
- prompt_wav=prompt_wav,
- text=text,
- model=model,
- vocoder=vocoder,
- tokenizer=tokenizer,
- feature_extractor=feature_extractor,
- device=device,
- num_step=num_step,
- guidance_scale=guidance_scale,
- speed=speed,
- t_shift=t_shift,
- target_rms=target_rms,
- feat_scale=feat_scale,
- sampling_rate=sampling_rate,
- )
- print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
- total_t.append(metrics["t"])
- total_t_no_vocoder.append(metrics["t_no_vocoder"])
- total_t_vocoder.append(metrics["t_vocoder"])
- total_wav_seconds.append(metrics["wav_seconds"])
-
- print(f"Average RTF: {np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
- print(
- f"Average RTF w/o vocoder: "
- f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
- )
- print(
- f"Average RTF vocoder: "
- f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
- )
-
-
-@torch.inference_mode()
-def main():
- parser = get_parser()
- args = parser.parse_args()
-
- params = get_params()
- params.update(vars(args))
-
- model_defaults = {
- "zipvoice": {
- "num_step": 16,
- "guidance_scale": 1.0,
- },
- "zipvoice_distill": {
- "num_step": 8,
- "guidance_scale": 3.0,
- },
- }
-
- model_specific_defaults = model_defaults.get(params.model_name, {})
-
- for param, value in model_specific_defaults.items():
- if getattr(params, param) == parser.get_default(param):
- setattr(params, param, value)
- print(f"Setting {param} to default value: {value}")
-
- assert (params.test_list is not None) ^ (
- params.prompt_wav is not None
- and params.prompt_text is not None
- and params.text is not None
- ), (
- "For inference, please provide prompts and text with either '--test-list'"
- " or '--prompt-wav, --prompt-text and --text'."
- )
-
- if torch.cuda.is_available():
- params.device = torch.device("cuda", 0)
- else:
- params.device = torch.device("cpu")
-
- token_file = hf_hub_download("zhu-han/ZipVoice", filename="tokens_emilia.txt")
-
- tokenizer = TokenizerEmilia(token_file)
-
- params.vocab_size = tokenizer.vocab_size
- params.pad_id = tokenizer.pad_id
- fix_random_seed(params.seed)
-
- if params.model_name == "zipvoice_distill":
- model = get_distill_model(params)
- model_ckpt = hf_hub_download(
- "zhu-han/ZipVoice", filename="exp_zipvoice_distill/model.safetensors"
- )
- else:
- model = get_model(params)
- model_ckpt = hf_hub_download(
- "zhu-han/ZipVoice", filename="exp_zipvoice/model.safetensors"
- )
-
- safetensors.torch.load_model(model, model_ckpt)
-
- model = model.to(params.device)
- model.eval()
-
- vocoder = get_vocoder()
- vocoder = vocoder.to(params.device)
- vocoder.eval()
-
- config = TorchAudioFbankConfig(
- sampling_rate=params.sampling_rate,
- n_mels=100,
- n_fft=1024,
- hop_length=256,
- )
- feature_extractor = TorchAudioFbank(config)
-
- if params.test_list:
- os.makedirs(params.res_dir, exist_ok=True)
- generate(
- res_dir=params.res_dir,
- test_list=params.test_list,
- model=model,
- vocoder=vocoder,
- tokenizer=tokenizer,
- feature_extractor=feature_extractor,
- device=params.device,
- num_step=params.num_step,
- guidance_scale=params.guidance_scale,
- speed=params.speed,
- t_shift=params.t_shift,
- target_rms=params.target_rms,
- feat_scale=params.feat_scale,
- sampling_rate=params.sampling_rate,
- )
- else:
- generate_sentence(
- save_path=params.res_wav_path,
- prompt_text=params.prompt_text,
- prompt_wav=params.prompt_wav,
- text=params.text,
- model=model,
- vocoder=vocoder,
- tokenizer=tokenizer,
- feature_extractor=feature_extractor,
- device=params.device,
- num_step=params.num_step,
- guidance_scale=params.guidance_scale,
- speed=params.speed,
- t_shift=params.t_shift,
- target_rms=params.target_rms,
- feat_scale=params.feat_scale,
- sampling_rate=params.sampling_rate,
- )
- print("Done")
-
-
-if __name__ == "__main__":
- main()