Add ZipVoice

Han Zhu 2025-06-16 09:45:34 +08:00
parent ffb7d05635
commit 9936d726d2
28 changed files with 13142 additions and 0 deletions

egs/zipvoice/README.md

@@ -0,0 +1,360 @@
## ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching
[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/)
[![demo](https://img.shields.io/badge/GitHub-Demo%20page-orange.svg)](https://zipvoice.github.io/)
## Overview
ZipVoice is a high-quality zero-shot TTS model with a small model size and fast inference speed.
#### Key features:
- Small and fast: only 123M parameters.
- High-quality: state-of-the-art voice cloning performance in speaker similarity, intelligibility, and naturalness.
- Multi-lingual: supports Chinese and English.
## News
**2025/06/16**: 🔥 ZipVoice is released.
## Installation
```
pip install -r requirements.txt
```
## Usage
To generate speech with our pre-trained ZipVoice or ZipVoice-Distill models, use the following commands (required models will be downloaded from HuggingFace automatically):
### 1. Inference of a single sentence:
```bash
python3 zipvoice/zipvoice_infer.py \
--model-name "zipvoice_distill" \
--prompt-wav prompt.wav \
--prompt-text "I am the transcription of the prompt wav." \
--text "I am the text to be synthesized." \
--res-wav-path result.wav
```
### 2. Inference of a list of sentences:
```bash
python3 zipvoice/zipvoice_infer.py \
--model-name "zipvoice_distill" \
--test-list test.tsv \
--res-dir results/test
```
- `--model-name` can be `zipvoice` or `zipvoice_distill`, which are models before and after distillation, respectively.
- Each line of `test.tsv` is in the format of `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
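For illustration, a two-entry `test.tsv` could look like the following (the utterance names, prompt paths, and transcriptions here are made up; the four fields are separated by tabs):
```
utt_0001	I am the transcription of the first prompt wav.	prompts/prompt_0001.wav	Hello, this is the first sentence to be synthesized.
utt_0002	I am the transcription of the second prompt wav.	prompts/prompt_0002.wav	And this is the second one.
```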
> **Note:** If you have trouble connecting to HuggingFace, try:
```bash
export HF_ENDPOINT=https://hf-mirror.com
```
## Training Your Own Model
The following steps show how to train a model from scratch on the Emilia and LibriTTS datasets, respectively.
### 1. Data Preparation
#### 1.1 Prepare the Emilia dataset
#### 1.2 Prepare the LibriTTS dataset
See [local/prepare_libritts.sh](local/prepare_libritts.sh)
### 2. Training
#### 2.1 Training on Emilia
<details>
<summary>Expand to view training steps</summary>
##### 2.1.1 Train the ZipVoice model
- Training:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_flow.py \
--world-size 8 \
--use-fp16 1 \
--dataset emilia \
--max-duration 500 \
--lr-hours 30000 \
--lr-batches 7500 \
--token-file "data/tokens_emilia.txt" \
--manifest-dir "data/fbank_emilia" \
--num-epochs 11 \
--exp-dir zipvoice/exp_zipvoice
```
- Average the checkpoints to produce the final model:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/generate_averaged_model.py \
--epoch 11 \
--avg 4 \
--distill 0 \
--token-file data/tokens_emilia.txt \
--dataset "emilia" \
--exp-dir ./zipvoice/exp_zipvoice
# The generated model is zipvoice/exp_zipvoice/epoch-11-avg-4.pt
```
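As a side note, checkpoint averaging itself is conceptually simple: it takes an element-wise mean of the saved parameter tensors. The sketch below is only an illustration; the checkpoint layout (a `model` key holding the state dict) and the averaged epoch range are assumptions, and `generate_averaged_model.py` remains the authoritative implementation (it may, for example, rely on the trainer's running parameter average instead).
```python
# Illustrative sketch only: not the recipe's actual implementation.
# Assumes each epoch checkpoint stores its parameters under a "model" key.
import torch


def average_checkpoints(paths):
    n = len(paths)
    avg = None
    for path in paths:
        state = torch.load(path, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / n for k, v in avg.items()}


# --epoch 11 --avg 4 would correspond to averaging epochs 8 through 11 (an assumption here).
paths = [f"zipvoice/exp_zipvoice/epoch-{e}.pt" for e in range(8, 12)]
torch.save({"model": average_checkpoints(paths)}, "zipvoice/exp_zipvoice/epoch-11-avg-4.pt")
```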
##### 2.1.2 Train the ZipVoice-Distill model (Optional)
- The first-stage distillation:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_distill.py \
--world-size 8 \
--use-fp16 1 \
--tensorboard 1 \
--dataset "emilia" \
--base-lr 0.0005 \
--max-duration 500 \
--token-file "data/tokens_emilia.txt" \
--manifest-dir "data/fbank_emilia" \
--teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
--num-updates 60000 \
--distill-stage "first" \
--exp-dir zipvoice/exp_zipvoice_distill_1stage
```
- Average checkpoints for the second-stage initialization:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/generate_averaged_model.py \
--iter 60000 \
--avg 7 \
--distill 1 \
--token-file data/tokens_emilia.txt \
--dataset "emilia" \
--exp-dir ./zipvoice/exp_zipvoice_distill_1stage
# The generated model is zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt
```
- The second-stage distillation:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_distill.py \
--world-size 8 \
--use-fp16 1 \
--tensorboard 1 \
--dataset "emilia" \
--base-lr 0.0001 \
--max-duration 200 \
--token-file "data/tokens_emilia.txt" \
--manifest-dir "data/fbank_emilia" \
--teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \
--num-updates 2000 \
--distill-stage "second" \
--exp-dir zipvoice/exp_zipvoice_distill
```
</details>
#### 2.2 Training on LibriTTS
<details>
<summary>Expand to view training steps</summary>
##### 2.2.1 Train the ZipVoice model
- Training:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_flow.py \
--world-size 8 \
--use-fp16 1 \
--dataset libritts \
--max-duration 250 \
--lr-epochs 10 \
--lr-batches 7500 \
--token-file "data/tokens_libritts.txt" \
--manifest-dir "data/fbank_libritts" \
--num-epochs 60 \
--exp-dir zipvoice/exp_zipvoice_libritts
```
- Average the checkpoints to produce the final model:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/generate_averaged_model.py \
--epoch 60 \
--avg 10 \
--distill 0 \
--token-file data/tokens_libritts.txt \
--dataset "libritts" \
--exp-dir ./zipvoice/exp_zipvoice_libritts
# The generated model is zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt
```
##### 2.2.2 Train the ZipVoice-Distill model (Optional)
- The first-stage distillation:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_distill.py \
--world-size 8 \
--use-fp16 1 \
--tensorboard 1 \
--dataset "libritts" \
--base-lr 0.001 \
--max-duration 250 \
--token-file "data/tokens_libritts.txt" \
--manifest-dir "data/fbank_libritts" \
--teacher-model zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
--num-epochs 6 \
--distill-stage "first" \
--exp-dir zipvoice/exp_zipvoice_distill_1stage_libritts
```
- Average checkpoints for the second-stage initialization:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 ./zipvoice/generate_averaged_model.py \
--epoch 6 \
--avg 3 \
--distill 1 \
--token-file data/tokens_libritts.txt \
--dataset "libritts" \
--exp-dir ./zipvoice/exp_zipvoice_distill_1stage_libritts
# The generated model is zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt
```
- The second-stage distillation:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_distill.py \
--world-size 8 \
--use-fp16 1 \
--tensorboard 1 \
--dataset "libritts" \
--base-lr 0.001 \
--max-duration 250 \
--token-file "data/tokens_libritts.txt" \
--manifest-dir "data/fbank_libritts" \
--teacher-model zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \
--num-epochs 6 \
--distill-stage "second" \
--exp-dir zipvoice/exp_zipvoice_distill_libritts
```
- Average checkpoints to produce the final model:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 ./zipvoice/generate_averaged_model.py \
--epoch 6 \
--avg 3 \
--distill 1 \
--token-file data/tokens_libritts.txt \
--dataset "libritts" \
--exp-dir ./zipvoice/exp_zipvoice_distill_libritts
# The generated model is ./zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt
```
</details>
### 3. Inference with the trained model
#### 3.1 Inference with the model trained on Emilia
<details>
<summary>Expand to view inference commands.</summary>
##### 3.1.1 ZipVoice model (without distillation):
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/infer.py \
--checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
--distill 0 \
--token-file "data/tokens_emilia.txt" \
--test-list test.tsv \
--res-dir results/test \
--num-step 16 \
--guidance-scale 1
```
##### 3.1.2 ZipVoice-Distill model:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/infer.py \
--checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
--distill 1 \
--token-file "data/tokens_emilia.txt" \
--test-list test.tsv \
--res-dir results/test_distill \
--num-step 8 \
--guidance-scale 3
```
</details>
#### 3.2 Inference with the model trained on LibriTTS
<details>
<summary>Expand to view inference commands.</summary>
##### 3.2.1 ZipVoice model (without distillation):
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/infer.py \
--checkpoint zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
--distill 0 \
--token-file "data/tokens_libritts.txt" \
--test-list test.tsv \
--res-dir results/test_libritts \
--num-step 8 \
--guidance-scale 1 \
--target-rms 1.0 \
--t-shift 0.7
```
##### 3.2.2 ZipVoice-Distill model:
```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/infer.py \
--checkpoint zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt \
--distill 1 \
--token-file "data/tokens_libritts.txt" \
--test-list test.tsv \
--res-dir results/test_distill_libritts \
--num-step 4 \
--guidance-scale 3 \
--target-rms 1.0 \
--t-shift 0.7
```
</details>
### 4. Evaluation on benchmarks
See [local/evaluate.sh](local/evaluate.sh) for details of the objective-metric evaluation on three test sets: LibriSpeech-PC test-clean, Seed-TTS test-en, and Seed-TTS test-zh.
## Citation
```bibtex
@article{zhu-2025-zipvoice,
title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching},
author={Han Zhu and Wei Kang and Zengwei Yao and Liyong Guo and Fangjun Kuang and Zhaoqing Li and Weiji Zhuang and Long Lin and Daniel Povey},
journal={arXiv preprint},
year={2025},
}
```

@@ -0,0 +1,140 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao,)
# 2024 The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LibriTTS dataset.
It looks for manifests in the directory data/manifests_libritts.
The generated fbank features are saved in data/fbank_libritts.
"""
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from feature import TorchAudioFbank, TorchAudioFbankConfig
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=24000,
help="""Sampling rate of the waveform for computing fbank,
the default value for LibriTTS is 24000, waveform files will be
resampled if a different sample rate is provided""",
)
return parser.parse_args()
def compute_fbank_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000):
src_dir = Path("data/manifests_libritts")
output_dir = Path("data/fbank_libritts")
num_jobs = min(32, os.cpu_count())
prefix = "libritts"
suffix = "jsonl.gz"
if dataset is None:
dataset_parts = (
"dev-clean",
"test-clean",
"train-clean-100",
"train-clean-360",
"train-other-500",
)
else:
dataset_parts = dataset.split(" ", -1)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
config = TorchAudioFbankConfig(
sampling_rate=sampling_rate,
n_mels=100,
n_fft=1024,
hop_length=256,
)
extractor = TorchAudioFbank(config)
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
                continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
if sampling_rate != 24000:
logging.info(f"Resampling waveforms to {sampling_rate}")
cut_set = cut_set.resample(sampling_rate)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    compute_fbank_libritts(dataset=args.dataset, sampling_rate=args.sampling_rate)

@@ -0,0 +1,102 @@
export CUDA_VISIBLE_DEVICES="0"
export PYTHONWARNINGS=ignore
export PYTHONPATH=../../:$PYTHONPATH
# Uncomment this if you have trouble connecting to HuggingFace
# export HF_ENDPOINT=https://hf-mirror.com
start_stage=1
end_stage=3
# Models used for SIM-o evaluation.
# SV model wavlm_large_finetune.pth is downloaded from https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
# SSL model wavlm_large.pt is downloaded from https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
sv_model_path=model/UniSpeech/wavlm_large_finetune.pth
wavlm_model_path=model/s3prl/wavlm_large.pt
# Models used for UTMOS evaluation.
# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
utmos_model_path=model/huggingface/utmos/utmos.pt
wav2vec_model_path=model/huggingface/utmos/wav2vec_small.pt
if [ $start_stage -le 1 ] && [ $end_stage -ge 1 ]; then
echo "=====Evaluate for Seed-TTS test-en======="
test_list=testset/test_seedtts_en.tsv
wav_path=results/zipvoice_seedtts_en
echo $wav_path
echo "-----Computing SIM-o-----"
python3 local/evaluate_sim.py \
--sv-model-path ${sv_model_path} \
--ssl-model-path ${wavlm_model_path} \
--eval-path ${wav_path} \
--test-list ${test_list}
echo "-----Computing WER-----"
python3 local/evaluate_wer_seedtts.py \
--test-list ${test_list} \
--wav-path ${wav_path} \
--lang "en"
echo "-----Computing UTSMOS-----"
python3 local/evaluate_utmos.py \
--wav-path ${wav_path} \
--utmos-model-path ${utmos_model_path} \
--ssl-model-path ${wav2vec_model_path}
fi
if [ $start_stage -le 2 ] && [ $end_stage -ge 2 ]; then
echo "=====Evaluate for Seed-TTS test-zh======="
test_list=testset/test_seedtts_zh.tsv
wav_path=results/zipvoice_seedtts_zh
echo $wav_path
echo "-----Computing SIM-o-----"
python3 local/evaluate_sim.py \
--sv-model-path ${sv_model_path} \
--ssl-model-path ${wavlm_model_path} \
--eval-path ${wav_path} \
--test-list ${test_list}
echo "-----Computing WER-----"
python3 local/evaluate_wer_seedtts.py \
--test-list ${test_list} \
--wav-path ${wav_path} \
--lang "zh"
echo "-----Computing UTSMOS-----"
python3 local/evaluate_utmos.py \
--wav-path ${wav_path} \
--utmos-model-path ${utmos_model_path} \
--ssl-model-path ${wav2vec_model_path}
fi
if [ $start_stage -le 3 ] && [ $end_stage -ge 3 ]; then
echo "=====Evaluate for Librispeech test-clean======="
test_list=testset/test_librispeech_pc_test_clean.tsv
wav_path=results/zipvoice_librispeech_test_clean
echo $wav_path
echo "-----Computing SIM-o-----"
python3 local/evaluate_sim.py \
--sv-model-path ${sv_model_path} \
--ssl-model-path ${wavlm_model_path} \
--eval-path ${wav_path} \
--test-list ${test_list}
echo "-----Computing WER-----"
python3 local/evaluate_wer_hubert.py \
--test-list ${test_list} \
--wav-path ${wav_path}
echo "-----Computing UTMOS-----"
python3 local/evaluate_utmos.py \
--wav-path ${wav_path} \
--utmos-model-path ${utmos_model_path} \
--ssl-model-path ${wav2vec_model_path}
fi

@@ -0,0 +1,508 @@
"""
Calculate pairwise Speaker Similarity between two speech directories.
SV model wavlm_large_finetune.pth is downloaded from
https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
SSL model wavlm_large.pt is downloaded from
https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
"""
import argparse
import logging
import os
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
logging.basicConfig(level=logging.INFO)
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--eval-path", type=str, help="path of the evaluated speech directory"
)
parser.add_argument(
"--test-list",
type=str,
help="path of the file list that contains the corresponding "
"relationship between the prompt and evaluated speech. "
"The first column is the wav name and the third column is the prompt speech",
)
parser.add_argument(
"--sv-model-path",
type=str,
default="model/UniSpeech/wavlm_large_finetune.pth",
help="path of the wavlm-based ECAPA-TDNN model",
)
parser.add_argument(
"--ssl-model-path",
type=str,
default="model/s3prl/wavlm_large.pt",
help="path of the wavlm SSL model",
)
return parser
class SpeakerSimilarity:
def __init__(
self,
sv_model_path="model/UniSpeech/wavlm_large_finetune.pth",
ssl_model_path="model/s3prl/wavlm_large.pt",
):
"""
Initialize
"""
self.sample_rate = 16000
self.channels = 1
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
logging.info("[Speaker Similarity] Using device: {}".format(self.device))
self.model = ECAPA_TDNN_WAVLLM(
feat_dim=1024,
channels=512,
emb_dim=256,
sr=16000,
ssl_model_path=ssl_model_path,
)
state_dict = torch.load(
sv_model_path, map_location=lambda storage, loc: storage
)
self.model.load_state_dict(state_dict["model"], strict=False)
self.model.to(self.device)
self.model.eval()
def get_embeddings(self, wav_list, dtype="float32"):
"""
Get embeddings
"""
def _load_speech_task(fname, sample_rate):
wav_data, sr = sf.read(fname, dtype=dtype)
if sr != sample_rate:
wav_data = librosa.resample(
wav_data, orig_sr=sr, target_sr=self.sample_rate
)
wav_data = torch.from_numpy(wav_data)
return wav_data
embd_lst = []
for file_path in tqdm(wav_list):
speech = _load_speech_task(file_path, self.sample_rate)
speech = speech.to(self.device)
with torch.no_grad():
embd = self.model([speech])
embd_lst.append(embd)
return embd_lst
def score(
self,
eval_path,
test_list,
dtype="float32",
):
"""
Computes the Speaker Similarity (SIM-o) between two directories of speech files.
Parameters:
- eval_path (str): Path to the directory containing evaluation speech files.
- test_list (str): Path to the file containing the corresponding relationship
between prompt and evaluated speech.
- dtype (str, optional): Data type for loading speech. Default is "float32".
Returns:
- float: The Speaker Similarity (SIM-o) score between the two directories
of speech files.
"""
prompt_wavs = []
eval_wavs = []
with open(test_list, "r") as fr:
lines = fr.readlines()
for line in lines:
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
prompt_wavs.append(prompt_wav)
eval_wavs.append(os.path.join(eval_path, wav_name + ".wav"))
embds_prompt = self.get_embeddings(prompt_wavs, dtype=dtype)
embds_eval = self.get_embeddings(eval_wavs, dtype=dtype)
# Check if embeddings are empty
if len(embds_prompt) == 0:
logging.info("[Speaker Similarity] real set dir is empty, exiting...")
return -1
if len(embds_eval) == 0:
logging.info("[Speaker Similarity] eval set dir is empty, exiting...")
return -1
scores = []
for real_embd, eval_embd in zip(embds_prompt, embds_eval):
scores.append(
torch.nn.functional.cosine_similarity(real_embd, eval_embd, dim=-1)
.detach()
.cpu()
.numpy()
)
return np.mean(scores)
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
""" Res2Conv1d + BatchNorm1d + ReLU
"""
class Res2Conv1dReluBn(nn.Module):
"""
in_channels == out_channels == channels
"""
def __init__(
self,
channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
bias=True,
scale=4,
):
super().__init__()
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
self.scale = scale
self.width = channels // scale
self.nums = scale if scale == 1 else scale - 1
self.convs = []
self.bns = []
for i in range(self.nums):
self.convs.append(
nn.Conv1d(
self.width,
self.width,
kernel_size,
stride,
padding,
dilation,
bias=bias,
)
)
self.bns.append(nn.BatchNorm1d(self.width))
self.convs = nn.ModuleList(self.convs)
self.bns = nn.ModuleList(self.bns)
def forward(self, x):
out = []
spx = torch.split(x, self.width, 1)
for i in range(self.nums):
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
# Order: conv -> relu -> bn
sp = self.convs[i](sp)
sp = self.bns[i](F.relu(sp))
out.append(sp)
if self.scale != 1:
out.append(spx[self.nums])
out = torch.cat(out, dim=1)
return out
""" Conv1d + BatchNorm1d + ReLU
"""
class Conv1dReluBn(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
bias=True,
):
super().__init__()
self.conv = nn.Conv1d(
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
)
self.bn = nn.BatchNorm1d(out_channels)
def forward(self, x):
return self.bn(F.relu(self.conv(x)))
""" The SE connection of 1D case.
"""
class SE_Connect(nn.Module):
def __init__(self, channels, se_bottleneck_dim=128):
super().__init__()
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
def forward(self, x):
out = x.mean(dim=2)
out = F.relu(self.linear1(out))
out = torch.sigmoid(self.linear2(out))
out = x * out.unsqueeze(2)
return out
""" SE-Res2Block of the ECAPA-TDNN architecture.
"""
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
# return nn.Sequential(
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
# SE_Connect(channels)
# )
class SE_Res2Block(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
scale,
se_bottleneck_dim,
):
super().__init__()
self.Conv1dReluBn1 = Conv1dReluBn(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
self.Res2Conv1dReluBn = Res2Conv1dReluBn(
out_channels, kernel_size, stride, padding, dilation, scale=scale
)
self.Conv1dReluBn2 = Conv1dReluBn(
out_channels, out_channels, kernel_size=1, stride=1, padding=0
)
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.Conv1dReluBn1(x)
x = self.Res2Conv1dReluBn(x)
x = self.Conv1dReluBn2(x)
x = self.SE_Connect(x)
return x + residual
""" Attentive weighted mean and standard deviation pooling.
"""
class AttentiveStatsPool(nn.Module):
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
super().__init__()
self.global_context_att = global_context_att
# Use Conv1d with stride == 1 rather than Linear,
# then we don't need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(
in_dim * 3, attention_channels, kernel_size=1
) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(
in_dim, attention_channels, kernel_size=1
) # equals W and b in the paper
self.linear2 = nn.Conv1d(
attention_channels, in_dim, kernel_size=1
) # equals V and k in the paper
def forward(self, x):
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(
torch.var(x, dim=-1, keepdim=True) + 1e-10
).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
alpha = torch.tanh(self.linear1(x_in))
# alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
std = torch.sqrt(residuals.clamp(min=1e-9))
return torch.cat([mean, std], dim=1)
class ECAPA_TDNN_WAVLLM(nn.Module):
def __init__(
self,
feat_dim=80,
channels=512,
emb_dim=192,
global_context_att=False,
sr=16000,
ssl_model_path=None,
):
super().__init__()
self.sr = sr
if ssl_model_path is None:
self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
else:
self.feature_extract = torch.hub.load(
os.path.dirname(ssl_model_path),
"wavlm_local",
source="local",
ckpt=ssl_model_path,
)
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
):
self.feature_extract.model.encoder.layers[
23
].self_attn.fp32_attention = False
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
):
self.feature_extract.model.encoder.layers[
11
].self_attn.fp32_attention = False
self.feat_num = self.get_feat_num()
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
self.instance_norm = nn.InstanceNorm1d(feat_dim)
# self.channels = [channels] * 4 + [channels * 3]
self.channels = [channels] * 4 + [1536]
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
self.layer2 = SE_Res2Block(
self.channels[0],
self.channels[1],
kernel_size=3,
stride=1,
padding=2,
dilation=2,
scale=8,
se_bottleneck_dim=128,
)
self.layer3 = SE_Res2Block(
self.channels[1],
self.channels[2],
kernel_size=3,
stride=1,
padding=3,
dilation=3,
scale=8,
se_bottleneck_dim=128,
)
self.layer4 = SE_Res2Block(
self.channels[2],
self.channels[3],
kernel_size=3,
stride=1,
padding=4,
dilation=4,
scale=8,
se_bottleneck_dim=128,
)
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
cat_channels = channels * 3
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
self.pooling = AttentiveStatsPool(
self.channels[-1],
attention_channels=128,
global_context_att=global_context_att,
)
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
def get_feat_num(self):
self.feature_extract.eval()
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
with torch.no_grad():
features = self.feature_extract(wav)
select_feature = features["hidden_states"]
if isinstance(select_feature, (list, tuple)):
return len(select_feature)
else:
return 1
def get_feat(self, x):
with torch.no_grad():
x = self.feature_extract([sample for sample in x])
x = x["hidden_states"]
if isinstance(x, (list, tuple)):
x = torch.stack(x, dim=0)
else:
x = x.unsqueeze(0)
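        # Combine the hidden states of all WavLM layers with a learned softmax-weighted sum.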
norm_weights = (
F.softmax(self.feature_weight, dim=-1)
.unsqueeze(-1)
.unsqueeze(-1)
.unsqueeze(-1)
)
x = (norm_weights * x).sum(dim=0)
x = torch.transpose(x, 1, 2) + 1e-6
x = self.instance_norm(x)
return x
def forward(self, x):
x = self.get_feat(x)
out1 = self.layer1(x)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
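        # Multi-layer feature aggregation: concatenate the outputs of the three SE-Res2Blocks.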
out = torch.cat([out2, out3, out4], dim=1)
out = F.relu(self.conv(out))
out = self.bn(self.pooling(out))
out = self.linear(out)
return out
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
SIM = SpeakerSimilarity(
sv_model_path=args.sv_model_path, ssl_model_path=args.ssl_model_path
)
score = SIM.score(args.eval_path, args.test_list)
logging.info(f"SIM-o score: {score:.3f}")

@@ -0,0 +1,294 @@
"""
Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
# Download model checkpoints
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
"""
import argparse
import logging
import os
import fairseq
import librosa
import numpy as np
import pytorch_lightning as pl
import soundfile as sf
import torch
import torch.nn as nn
from tqdm import tqdm
logging.basicConfig(level=logging.INFO)
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--wav-path", type=str, help="path of the evaluated speech directory"
)
parser.add_argument(
"--utmos-model-path",
type=str,
default="model/huggingface/utmos/utmos.pt",
help="path of the UTMOS model",
)
parser.add_argument(
"--ssl-model-path",
type=str,
default="model/huggingface/utmos/wav2vec_small.pt",
help="path of the wav2vec SSL model",
)
return parser
class UTMOSScore:
"""Predicting score for each audio clip."""
def __init__(self, utmos_model_path, ssl_model_path):
self.sample_rate = 16000
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
self.model = (
BaselineLightningModule.load_from_checkpoint(
utmos_model_path, ssl_model_path=ssl_model_path
)
.eval()
.to(self.device)
)
def score(self, wavs: torch.Tensor) -> torch.Tensor:
"""
Args:
            wavs: waveforms to be evaluated. A 1-D or 2-D tensor is treated
                as a single audio clip; a 3-D tensor of shape
                (batch, 1, num_samples) is processed as a batch.
"""
if len(wavs.shape) == 1:
out_wavs = wavs.unsqueeze(0).unsqueeze(0)
elif len(wavs.shape) == 2:
out_wavs = wavs.unsqueeze(0)
elif len(wavs.shape) == 3:
out_wavs = wavs
else:
raise ValueError("Dimension of input tensor needs to be <= 3.")
bs = out_wavs.shape[0]
batch = {
"wav": out_wavs,
"domains": torch.zeros(bs, dtype=torch.int).to(self.device),
"judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
}
with torch.no_grad():
output = self.model(batch)
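        # Rescale the model's normalized prediction back to the 1-5 MOS range.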
return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
def score_dir(self, dir, dtype="float32"):
def _load_speech_task(fname, sample_rate):
wav_data, sr = sf.read(fname, dtype=dtype)
if sr != sample_rate:
wav_data = librosa.resample(
wav_data, orig_sr=sr, target_sr=self.sample_rate
)
wav_data = torch.from_numpy(wav_data)
return wav_data
score_lst = []
for fname in tqdm(os.listdir(dir)):
speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate)
speech = speech.to(self.device)
with torch.no_grad():
score = self.score(speech)
score_lst.append(score.item())
return np.mean(score_lst)
def load_ssl_model(ckpt_path="wav2vec_small.pt"):
SSL_OUT_DIM = 768
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[ckpt_path]
)
ssl_model = model[0]
ssl_model.remove_pretraining_modules()
return SSL_model(ssl_model, SSL_OUT_DIM)
class BaselineLightningModule(pl.LightningModule):
def __init__(self, ssl_model_path):
super().__init__()
self.construct_model(ssl_model_path)
self.save_hyperparameters()
def construct_model(self, ssl_model_path):
self.feature_extractors = nn.ModuleList(
[
load_ssl_model(ckpt_path=ssl_model_path),
DomainEmbedding(3, 128),
]
)
output_dim = sum(
[
feature_extractor.get_output_dim()
for feature_extractor in self.feature_extractors
]
)
output_layers = [
LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
]
output_dim = output_layers[-1].get_output_dim()
output_layers.append(
Projection(
hidden_dim=2048,
activation=torch.nn.ReLU(),
range_clipping=False,
input_dim=output_dim,
)
)
self.output_layers = nn.ModuleList(output_layers)
def forward(self, inputs):
outputs = {}
for feature_extractor in self.feature_extractors:
outputs.update(feature_extractor(inputs))
x = outputs
for output_layer in self.output_layers:
x = output_layer(x, inputs)
return x
class SSL_model(nn.Module):
def __init__(self, ssl_model, ssl_out_dim) -> None:
super(SSL_model, self).__init__()
self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
def forward(self, batch):
wav = batch["wav"]
wav = wav.squeeze(1) # [batches, wav_len]
res = self.ssl_model(wav, mask=False, features_only=True)
x = res["x"]
return {"ssl-feature": x}
def get_output_dim(self):
return self.ssl_out_dim
class DomainEmbedding(nn.Module):
def __init__(self, n_domains, domain_dim) -> None:
super().__init__()
self.embedding = nn.Embedding(n_domains, domain_dim)
self.output_dim = domain_dim
def forward(self, batch):
return {"domain-feature": self.embedding(batch["domains"])}
def get_output_dim(self):
return self.output_dim
class LDConditioner(nn.Module):
"""
Conditions ssl output by listener embedding
"""
def __init__(self, input_dim, judge_dim, num_judges=None):
super().__init__()
self.input_dim = input_dim
self.judge_dim = judge_dim
self.num_judges = num_judges
assert num_judges != None
self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
# concat [self.output_layer, phoneme features]
self.decoder_rnn = nn.LSTM(
input_size=self.input_dim + self.judge_dim,
hidden_size=512,
num_layers=1,
batch_first=True,
bidirectional=True,
) # linear?
self.out_dim = self.decoder_rnn.hidden_size * 2
def get_output_dim(self):
return self.out_dim
def forward(self, x, batch):
judge_ids = batch["judge_id"]
if "phoneme-feature" in x.keys():
concatenated_feature = torch.cat(
(
x["ssl-feature"],
x["phoneme-feature"]
.unsqueeze(1)
.expand(-1, x["ssl-feature"].size(1), -1),
),
dim=2,
)
else:
concatenated_feature = x["ssl-feature"]
if "domain-feature" in x.keys():
concatenated_feature = torch.cat(
(
concatenated_feature,
x["domain-feature"]
.unsqueeze(1)
.expand(-1, concatenated_feature.size(1), -1),
),
dim=2,
)
if judge_ids != None:
concatenated_feature = torch.cat(
(
concatenated_feature,
self.judge_embedding(judge_ids)
.unsqueeze(1)
.expand(-1, concatenated_feature.size(1), -1),
),
dim=2,
)
decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
return decoder_output
class Projection(nn.Module):
def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
super(Projection, self).__init__()
self.range_clipping = range_clipping
output_dim = 1
if range_clipping:
self.proj = nn.Tanh()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
activation,
nn.Dropout(0.3),
nn.Linear(hidden_dim, output_dim),
)
self.output_dim = output_dim
def forward(self, x, batch):
output = self.net(x)
# range clipping
if self.range_clipping:
return self.proj(output) * 2.0 + 3
else:
return output
def get_output_dim(self):
return self.output_dim
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
UTMOS = UTMOSScore(
utmos_model_path=args.utmos_model_path, ssl_model_path=args.ssl_model_path
)
score = UTMOS.score_dir(args.wav_path)
logging.info(f"UTMOS score: {score:.2f}")

@@ -0,0 +1,172 @@
"""
Calculate WER with Hubert models.
"""
import argparse
import os
import re
from pathlib import Path
import librosa
import numpy as np
import soundfile as sf
import torch
from jiwer import compute_measures
from tqdm import tqdm
from transformers import pipeline
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
parser.add_argument(
"--decode-path",
type=str,
default=None,
help="path of the output file of WER information",
)
parser.add_argument(
"--model-path",
type=str,
default=None,
help="path of the local hubert model, e.g., model/huggingface/hubert-large-ls960-ft",
)
parser.add_argument(
"--test-list",
type=str,
default="test.tsv",
help="path of the transcript tsv file, where the first column "
"is the wav name and the last column is the transcript",
)
parser.add_argument(
"--batch-size", type=int, default=16, help="decoding batch size"
)
return parser
def post_process(text: str):
    # Normalize curly apostrophes to the plain ASCII apostrophe before scoring.
    text = text.replace("‘", "'")
    text = text.replace("’", "'")
text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def process_one(hypo, truth):
truth = post_process(truth)
hypo = post_process(hypo)
measures = compute_measures(truth, hypo)
word_num = len(truth.split(" "))
wer = measures["wer"]
subs = measures["substitutions"]
dele = measures["deletions"]
inse = measures["insertions"]
return (truth, hypo, wer, subs, dele, inse, word_num)
class SpeechEvalDataset(torch.utils.data.Dataset):
def __init__(self, wav_path: str, test_list: str):
super().__init__()
self.wav_name = []
self.wav_paths = []
self.transcripts = []
with Path(test_list).open("r", encoding="utf8") as f:
meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
for item in meta:
self.wav_name.append(item[0])
self.wav_paths.append(Path(wav_path, item[0] + ".wav"))
self.transcripts.append(item[-1])
def __len__(self):
return len(self.wav_paths)
def __getitem__(self, index: int):
wav, sampling_rate = sf.read(self.wav_paths[index])
item = {
"array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000),
"sampling_rate": 16000,
"reference": self.transcripts[index],
"wav_name": self.wav_name[index],
}
return item
def main(test_list, wav_path, model_path, decode_path, batch_size, device):
if model_path is not None:
pipe = pipeline(
"automatic-speech-recognition",
model=model_path,
device=device,
tokenizer=model_path,
)
else:
pipe = pipeline(
"automatic-speech-recognition",
model="facebook/hubert-large-ls960-ft",
device=device,
)
dataset = SpeechEvalDataset(wav_path, test_list)
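    # Stream the dataset through the ASR pipeline so decoding is batched; tqdm reports progress.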
bar = tqdm(
pipe(
dataset,
generate_kwargs={"language": "english", "task": "transcribe"},
batch_size=batch_size,
),
total=len(dataset),
)
wers = []
inses = []
deles = []
subses = []
word_nums = 0
if decode_path:
decode_dir = os.path.dirname(decode_path)
if not os.path.exists(decode_dir):
os.makedirs(decode_dir)
fout = open(decode_path, "w")
for out in bar:
wav_name = out["wav_name"][0]
transcription = post_process(out["text"].strip())
text_ref = post_process(out["reference"][0].strip())
truth, hypo, wer, subs, dele, inse, word_num = process_one(
transcription, text_ref
)
if decode_path:
fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
wers.append(float(wer))
inses.append(float(inse))
deles.append(float(dele))
subses.append(float(subs))
word_nums += word_num
wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
subs = round(np.mean(subses) * 100, 3)
dele = round(np.mean(deles) * 100, 3)
inse = round(np.mean(inses) * 100, 3)
print(f"WER: {wer}%\n")
if decode_path:
fout.write(f"WER: {wer}%\n")
fout.flush()
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
if torch.cuda.is_available():
device = torch.device("cuda", 0)
else:
device = torch.device("cpu")
main(
args.test_list,
args.wav_path,
args.model_path,
args.decode_path,
args.batch_size,
device,
)

@@ -0,0 +1,181 @@
"""
Calculate WER with Whisper-large-v3 or Paraformer models,
following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval
"""
import argparse
import os
import string
import numpy as np
import scipy
import soundfile as sf
import torch
import zhconv
from funasr import AutoModel
from jiwer import compute_measures
from tqdm import tqdm
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from zhon.hanzi import punctuation
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
parser.add_argument(
"--decode-path",
type=str,
default=None,
help="path of the output file of WER information",
)
parser.add_argument(
"--model-path",
type=str,
default=None,
help="path of the local whisper and paraformer model, "
"e.g., whisper: model/huggingface/whisper-large-v3/, "
"paraformer: model/huggingface/paraformer-zh/",
)
parser.add_argument(
"--test-list",
type=str,
default="test.tsv",
help="path of the transcript tsv file, where the first column "
"is the wav name and the last column is the transcript",
)
parser.add_argument("--lang", type=str, help="decoded language, zh or en")
return parser
def load_en_model(model_path):
if model_path is None:
model_path = "openai/whisper-large-v3"
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)
return processor, model
def load_zh_model(model_path):
if model_path is None:
model_path = "paraformer-zh"
model = AutoModel(model=model_path)
return model
def process_one(hypo, truth, lang):
punctuation_all = punctuation + string.punctuation
for x in punctuation_all:
if x == "'":
continue
truth = truth.replace(x, "")
hypo = hypo.replace(x, "")
    # Collapse the double spaces left behind by punctuation removal.
    truth = truth.replace("  ", " ")
    hypo = hypo.replace("  ", " ")
if lang == "zh":
truth = " ".join([x for x in truth])
hypo = " ".join([x for x in hypo])
elif lang == "en":
truth = truth.lower()
hypo = hypo.lower()
else:
raise NotImplementedError
measures = compute_measures(truth, hypo)
word_num = len(truth.split(" "))
wer = measures["wer"]
subs = measures["substitutions"]
dele = measures["deletions"]
inse = measures["insertions"]
return (truth, hypo, wer, subs, dele, inse, word_num)
def main(test_list, wav_path, model_path, decode_path, lang, device):
if lang == "en":
processor, model = load_en_model(model_path)
model.to(device)
elif lang == "zh":
model = load_zh_model(model_path)
params = []
for line in open(test_list).readlines():
line = line.strip()
items = line.split("\t")
wav_name, text_ref = items[0], items[-1]
file_path = os.path.join(wav_path, wav_name + ".wav")
assert os.path.exists(file_path), f"{file_path}"
params.append((file_path, text_ref))
wers = []
inses = []
deles = []
subses = []
word_nums = 0
if decode_path:
decode_dir = os.path.dirname(decode_path)
if not os.path.exists(decode_dir):
os.makedirs(decode_dir)
fout = open(decode_path, "w")
for wav_path, text_ref in tqdm(params):
if lang == "en":
wav, sr = sf.read(wav_path)
if sr != 16000:
wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
input_features = processor(
wav, sampling_rate=16000, return_tensors="pt"
).input_features
input_features = input_features.to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(
language="english", task="transcribe"
)
predicted_ids = model.generate(
input_features, forced_decoder_ids=forced_decoder_ids
)
transcription = processor.batch_decode(
predicted_ids, skip_special_tokens=True
)[0]
elif lang == "zh":
res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
transcription = res[0]["text"]
transcription = zhconv.convert(transcription, "zh-cn")
truth, hypo, wer, subs, dele, inse, word_num = process_one(
transcription, text_ref, lang
)
if decode_path:
fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
wers.append(float(wer))
inses.append(float(inse))
deles.append(float(dele))
subses.append(float(subs))
word_nums += word_num
wer_avg = round(np.mean(wers) * 100, 3)
wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
subs = round(np.mean(subses) * 100, 3)
dele = round(np.mean(deles) * 100, 3)
inse = round(np.mean(inses) * 100, 3)
print(f"Seed-TTS WER: {wer_avg}%\n")
print(f"WER: {wer}%\n")
if decode_path:
fout.write(f"SeedTTS WER: {wer_avg}%\n")
fout.write(f"WER: {wer}%\n")
fout.flush()
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
if torch.cuda.is_available():
device = torch.device("cuda", 0)
else:
device = torch.device("cpu")
main(
args.test_list,
args.wav_path,
args.model_path,
args.decode_path,
args.lang,
device,
)

@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames
class MelSpectrogramFeatures(nn.Module):
def __init__(
self,
sampling_rate=24000,
n_mels=100,
n_fft=1024,
hop_length=256,
):
super().__init__()
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sampling_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
center=True,
power=1,
)
def forward(self, inp):
assert len(inp.shape) == 2
mel = self.mel_spec(inp)
logmel = mel.clamp(min=1e-7).log()
return logmel
@dataclass
class TorchAudioFbankConfig:
    sampling_rate: int
    n_mels: int
    n_fft: int
    hop_length: int
    # Backing field for the `device` property of TorchAudioFbank.
    device: Union[str, torch.device] = "cpu"
@register_extractor
class TorchAudioFbank(FeatureExtractor):
name = "TorchAudioFbank"
config_type = TorchAudioFbankConfig
def __init__(self, config):
super().__init__(config=config)
def _feature_fn(self, sample):
fbank = MelSpectrogramFeatures(
sampling_rate=self.config.sampling_rate,
n_mels=self.config.n_mels,
n_fft=self.config.n_fft,
hop_length=self.config.hop_length,
)
return fbank(sample)
@property
def device(self) -> Union[str, torch.device]:
return self.config.device
def feature_dim(self, sampling_rate: int) -> int:
return self.config.n_mels
def extract(
self,
samples: Union[np.ndarray, torch.Tensor],
sampling_rate: int,
) -> Union[np.ndarray, torch.Tensor]:
# Check for sampling rate compatibility.
expected_sr = self.config.sampling_rate
assert sampling_rate == expected_sr, (
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
is_numpy = False
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
is_numpy = True
if len(samples.shape) == 1:
samples = samples.unsqueeze(0)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape
mel = self._feature_fn(samples).squeeze().t()
assert mel.ndim == 2, mel.shape
assert mel.shape[1] == self.config.n_mels, mel.shape
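        # Align to the frame count Lhotse expects for this duration: truncate extra frames or replicate-pad at the end.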
num_frames = compute_num_frames(
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
)
if mel.shape[0] > num_frames:
mel = mel[:num_frames]
elif mel.shape[0] < num_frames:
mel = mel.unsqueeze(0)
mel = torch.nn.functional.pad(
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)
if is_numpy:
return mel.cpu().numpy()
else:
return mel
@property
def frame_shift(self) -> Seconds:
return self.config.hop_length / self.config.sampling_rate

@@ -0,0 +1,88 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=0
stop_stage=5
sampling_rate=24000
nj=32
dl_dir=$PWD/download
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/LibriTTS,
# you can create a symlink
#
# ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
#
if [ ! -d $dl_dir/LibriTTS ]; then
lhotse download libritts $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LibriTTS manifest"
# We assume that you have downloaded the LibriTTS corpus
# to $dl_dir/LibriTTS
mkdir -p data/manifests_libritts
if [ ! -e data/manifests_libritts/.libritts.done ]; then
lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests_libritts
touch data/manifests_libritts/.libritts.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute Fbank for LibriTTS"
mkdir -p data/fbank_libritts
if [ ! -e data/fbank_libritts/.libritts.done ]; then
./local/compute_fbank_libritts.py --sampling-rate $sampling_rate
touch data/fbank_libritts/.libritts.done
fi
# Here we shuffle and combine the train-clean-100, train-clean-360 and
# train-other-500 together to form the training set.
if [ ! -f data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz ]; then
cat <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-100.jsonl.gz) \
<(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-360.jsonl.gz) \
<(gunzip -c data/fbank_libritts/libritts_cuts_train-other-500.jsonl.gz) | \
shuf | gzip -c > data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz
fi
if [ ! -f data/fbank_libritts/libritts_cuts_train-clean-460.jsonl.gz ]; then
cat <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-100.jsonl.gz) \
<(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-360.jsonl.gz) | \
shuf | gzip -c > data/fbank_libritts/libritts_cuts_train-clean-460.jsonl.gz
fi
if [ ! -e data/fbank_libritts/.libritts-validated.done ]; then
log "Validating data/fbank for LibriTTS"
./local/validate_manifest.py \
data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz
touch data/fbank_libritts/.libritts-validated.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 4: Generate token file"
if [ ! -e data/tokens_libritts.txt ]; then
./local/prepare_token_file_libritts.py --tokens data/tokens_libritts.txt
fi
fi

@@ -0,0 +1,90 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file generates the file that maps tokens to IDs.
"""
import argparse
import logging
from pathlib import Path
from typing import List
from piper_phonemize import get_espeak_map
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens_emilia.txt"),
help="Path to the dict that maps the text tokens to IDs",
)
parser.add_argument(
"--pinyin",
type=Path,
default=Path("local/pinyin.txt"),
help="Path to the all unique pinyin",
)
return parser.parse_args()
def get_pinyin_tokens(pinyin: Path) -> List[str]:
phones = set()
with open(pinyin, "r") as f:
for line in f:
x = line.strip()
initial = to_initials(x, strict=False)
# don't want to share tokens with espeak tokens, so use tone3 style
finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
if initial != "":
# don't want to share tokens with espeak tokens, so add a '0' after each initial
phones.add(initial + "0")
if finals != "":
phones.add(finals)
return sorted(phones)
def get_token2id(args):
"""Get a dict that maps token to IDs, and save it to the given filename."""
all_tokens = get_espeak_map() # token: [token_id]
all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
# sort by token_id
all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
all_pinyin = get_pinyin_tokens(args.pinyin)
with open(args.tokens, "w", encoding="utf-8") as f:
for token, token_id in all_tokens:
f.write(f"{token} {token_id}\n")
num_espeak_tokens = len(all_tokens)
for i, pinyin in enumerate(all_pinyin):
f.write(f"{pinyin} {num_espeak_tokens + i}\n")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
get_token2id(args)

@@ -0,0 +1,31 @@
import re
from collections import Counter
from lhotse import load_manifest_lazy
def prepare_tokens(manifest_file, token_file):
counter = Counter()
manifest = load_manifest_lazy(manifest_file)
for cut in manifest:
line = re.sub(r"\s+", " ", cut.supervisions[0].text)
counter.update(line)
unique_chars = set(counter.keys())
if "_" in unique_chars:
unique_chars.remove("_")
sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True)
result = ["_"] + sorted_chars
with open(token_file, "w", encoding="utf-8") as file:
for index, char in enumerate(result):
file.write(f"{char} {index}\n")
if __name__ == "__main__":
manifest_file = "data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz"
output_token_file = "data/tokens_libritts.txt"
prepare_tokens(manifest_file, output_token_file)

@@ -0,0 +1,192 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
# Zengrui Jin,
# Wei Kang)
# 2024 Tsinghua University (authors: Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file reads the texts in given manifest and save the new cuts with phoneme tokens.
"""
import argparse
import glob
import logging
import re
from concurrent.futures import ProcessPoolExecutor as Pool
from pathlib import Path
from typing import List
import jieba
from lhotse import load_manifest_lazy
from tokenizer import Tokenizer, is_alphabet, is_chinese, is_hangul, is_japanese
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--subset",
type=str,
help="Subset of emilia, (ZH, EN, etc.)",
)
parser.add_argument(
"--jobs",
type=int,
default=50,
help="Number of jobs to processing.",
)
parser.add_argument(
"--source-dir",
type=str,
default="data/manifests_emilia/splits",
help="The source directory of manifest files.",
)
parser.add_argument(
"--dest-dir",
type=str,
help="The destination directory of manifest files.",
)
return parser.parse_args()
def tokenize_by_CJK_char(line: str) -> List[str]:
"""
Tokenize a line of text with CJK char.
Note: All return characters will be upper case.
Example:
input = "你好世界是 hello world 的中文"
      output = ['你', '好', '世', '界', '是', 'HELLO', 'WORLD', '的', '中', '文']
Args:
line:
The input text.
Return:
A new string tokenize by CJK char.
"""
    # The CJK ranges are from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
pattern = re.compile(
r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
)
chars = pattern.split(line.strip().upper())
char_list = []
for w in chars:
if w.strip():
char_list += w.strip().split()
return char_list
def prepare_tokens_emilia(file_name: str, input_dir: Path, output_dir: Path):
logging.info(f"Processing {file_name}")
if (output_dir / file_name).is_file():
logging.info(f"{file_name} exists, skipping.")
return
jieba.setLogLevel(logging.INFO)
tokenizer = Tokenizer()
def _prepare_cut(cut):
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
text = cut.supervisions[0].text
cut.supervisions[0].normalized_text = text
tokens = tokenizer.texts_to_tokens([text])[0]
cut.tokens = tokens
return cut
def _filter_cut(cut):
text = cut.supervisions[0].text
duration = cut.supervisions[0].duration
chinese = []
english = []
# only contains chinese and space and alphabets
clean_chars = []
for x in text:
if is_hangul(x):
logging.info(f"Delete cut with text containing Korean : {text}")
return False
if is_japanese(x):
logging.info(f"Delete cut with text containing Japanese : {text}")
return False
if is_chinese(x):
chinese.append(x)
clean_chars.append(x)
if is_alphabet(x):
english.append(x)
clean_chars.append(x)
if x == " ":
clean_chars.append(x)
if len(english) + len(chinese) == 0:
logging.info(f"Delete cut with text has no valid chars : {text}")
return False
words = tokenize_by_CJK_char("".join(clean_chars))
for i in range(len(words) - 10):
if words[i : i + 10].count(words[i]) == 10:
logging.info(f"Delete cut with text with too much repeats : {text}")
return False
        # word speed: 20 - 600 words / minute (see the worked example below)
if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60:
logging.info(
f"Delete cut with audio text mismatch, duration : {duration}s, words : {len(words)}, text : {text}"
)
return False
return True
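    # Worked example of the word-speed filter above (illustration only): for a
    # cut whose cleaned text has 100 words, the allowed duration lies between
    # 100 / 600 * 60 = 10 s (600 words per minute) and 100 / 20 * 60 = 300 s
    # (20 words per minute); cuts outside this range are treated as
    # audio/text mismatches and dropped.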
try:
cut_set = load_manifest_lazy(input_dir / file_name)
cut_set = cut_set.filter(_filter_cut)
cut_set = cut_set.map(_prepare_cut)
cut_set.to_file(output_dir / file_name)
except Exception as e:
logging.error(f"Manifest {file_name} failed with error: {e}")
raise
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
input_dir = Path(args.source_dir)
output_dir = Path(args.dest_dir)
output_dir.mkdir(parents=True, exist_ok=True)
cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")
with Pool(max_workers=args.jobs) as pool:
futures = [
pool.submit(
prepare_tokens_emilia, filename.split("/")[-1], input_dir, output_dir
)
for filename in cut_files
]
for f in futures:
try:
f.result()
except Exception as e:
logging.error(f"Future failed with error: {e}")
logging.info("Processing done.")

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/spectrogram/ljspeech_cuts_all.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet), type(cut_set)
validate_for_tts(cut_set)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,142 @@
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
# use duck typing for LRScheduler since we have different possibilities, see
# our class LRScheduler.
LRSchedulerType = object
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
model_ema: Optional[nn.Module] = None,
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
"""Save training information to a file.
Args:
filename:
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
model_avg:
The stored model averaged from the start of training.
model_ema:
The EMA version of model.
params:
User defined parameters, e.g., epoch, loss.
optimizer:
The optimizer to be saved. We only save its `state_dict()`.
scheduler:
The scheduler to be saved. We only save its `state_dict()`.
      scaler:
The GradScaler to be saved. We only save its `state_dict()`.
sampler:
The sampler used in the labeled training dataset. We only
save its `state_dict()`.
rank:
Used in DDP. We save checkpoint only for the node whose
rank is 0.
Returns:
Return None.
"""
if rank != 0:
return
logging.info(f"Saving checkpoint to {filename}")
if isinstance(model, DDP):
model = model.module
checkpoint = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"scheduler": scheduler.state_dict() if scheduler is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
"sampler": sampler.state_dict() if sampler is not None else None,
}
if model_avg is not None:
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
if model_ema is not None:
checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict()
if params:
for k, v in params.items():
assert k not in checkpoint
checkpoint[k] = v
torch.save(checkpoint, filename)
def load_checkpoint(
filename: Path,
model: Optional[nn.Module] = None,
model_avg: Optional[nn.Module] = None,
model_ema: Optional[nn.Module] = None,
strict: bool = False,
) -> Dict[str, Any]:
logging.info(f"Loading checkpoint from {filename}")
checkpoint = torch.load(filename, map_location="cpu")
if model is not None:
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
dst_state_dict = model.state_dict()
src_state_dict = checkpoint["model"]
for key in dst_state_dict.keys():
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=strict)
else:
logging.info("Loading checkpoint")
model.load_state_dict(checkpoint["model"], strict=strict)
checkpoint.pop("model")
if model_avg is not None and "model_avg" in checkpoint:
logging.info("Loading averaged model")
model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
checkpoint.pop("model_avg")
if model_ema is not None and "model_ema" in checkpoint:
logging.info("Loading ema model")
model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
checkpoint.pop("model_ema")
return checkpoint
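# Minimal usage sketch (illustration only; the module, file name and params
# below are hypothetical, not part of the recipe):
#
#   model = torch.nn.Linear(4, 4)
#   optimizer = torch.optim.Adam(model.parameters())
#   save_checkpoint(Path("exp/checkpoint-1.pt"), model, optimizer=optimizer,
#                   params={"batch_idx": 100})
#   info = load_checkpoint(Path("exp/checkpoint-1.pt"), model=model)
#   # `info` still holds the non-model entries, e.g. info["batch_idx"] == 100
#   # and info["optimizer"] (an optimizer state_dict).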

View File

@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames
class MelSpectrogramFeatures(nn.Module):
def __init__(
self,
sampling_rate=24000,
n_mels=100,
n_fft=1024,
hop_length=256,
):
super().__init__()
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sampling_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
center=True,
power=1,
)
def forward(self, inp):
assert len(inp.shape) == 2
mel = self.mel_spec(inp)
logmel = mel.clamp(min=1e-7).log()
return logmel
@dataclass
class TorchAudioFbankConfig:
sampling_rate: int
n_mels: int
n_fft: int
hop_length: int
@register_extractor
class TorchAudioFbank(FeatureExtractor):
name = "TorchAudioFbank"
config_type = TorchAudioFbankConfig
def __init__(self, config):
super().__init__(config=config)
def _feature_fn(self, sample):
fbank = MelSpectrogramFeatures(
sampling_rate=self.config.sampling_rate,
n_mels=self.config.n_mels,
n_fft=self.config.n_fft,
hop_length=self.config.hop_length,
)
return fbank(sample)
@property
def device(self) -> Union[str, torch.device]:
return self.config.device
def feature_dim(self, sampling_rate: int) -> int:
return self.config.n_mels
def extract(
self,
samples: Union[np.ndarray, torch.Tensor],
sampling_rate: int,
) -> Union[np.ndarray, torch.Tensor]:
# Check for sampling rate compatibility.
expected_sr = self.config.sampling_rate
assert sampling_rate == expected_sr, (
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
is_numpy = False
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
is_numpy = True
if len(samples.shape) == 1:
samples = samples.unsqueeze(0)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape
mel = self._feature_fn(samples).squeeze().t()
assert mel.ndim == 2, mel.shape
assert mel.shape[1] == self.config.n_mels, mel.shape
num_frames = compute_num_frames(
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
)
if mel.shape[0] > num_frames:
mel = mel[:num_frames]
elif mel.shape[0] < num_frames:
mel = mel.unsqueeze(0)
mel = torch.nn.functional.pad(
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)
if is_numpy:
return mel.cpu().numpy()
else:
return mel
@property
def frame_shift(self) -> Seconds:
return self.config.hop_length / self.config.sampling_rate
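# Minimal usage sketch (illustration only): extracting log-mel features for
# one second of 24 kHz audio.  The exact frame count comes from lhotse's
# compute_num_frames with frame_shift = 256 / 24000.
#
#   config = TorchAudioFbankConfig(
#       sampling_rate=24000, n_mels=100, n_fft=1024, hop_length=256
#   )
#   extractor = TorchAudioFbank(config)
#   audio = np.zeros((1, 24000), dtype=np.float32)   # (1, num_samples)
#   feats = extractor.extract(audio, sampling_rate=24000)
#   # feats has shape (num_frames, 100), roughly num_samples / hop_length frames.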

View File

@@ -0,0 +1,209 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
This script loads checkpoints and averages them.
(1) Average ZipVoice models before distill:
python3 ./zipvoice/generate_averaged_model.py \
--epoch 11 \
--avg 4 \
--distill 0 \
--token-file data/tokens_emilia.txt \
--exp-dir ./zipvoice/exp_zipvoice
It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-11-avg-4.pt")`.
(2) Average ZipVoice-Distill models (the first stage model):
python3 ./zipvoice/generate_averaged_model.py \
--iter 60000 \
--avg 7 \
--distill 1 \
--token-file data/tokens_emilia.txt \
--exp-dir ./zipvoice/exp_zipvoice_distill_1stage
"""
import argparse
from pathlib import Path
import torch
from model import get_distill_model, get_model
from tokenizer import TokenizerEmilia, TokenizerLibriTTS
from train_flow import add_model_arguments, get_params
from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=11,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=4,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' or --iter",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipvoice/exp_zipvoice",
help="The experiment dir",
)
parser.add_argument(
"--distill",
type=str2bool,
default=False,
help="Whether to use distill model. ",
)
parser.add_argument(
"--dataset",
type=str,
default="emilia",
choices=["emilia", "libritts"],
help="The used training dataset for the model to inference",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
if params.dataset == "emilia":
tokenizer = TokenizerEmilia(
token_file=params.token_file, token_type=params.token_type
)
elif params.dataset == "libritts":
tokenizer = TokenizerLibriTTS(
token_file=params.token_file, token_type=params.token_type
)
params.vocab_size = tokenizer.vocab_size
params.pad_id = tokenizer.pad_id
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
print("Script started")
params.device = torch.device("cpu")
print(f"Device: {params.device}")
print("About to create model")
if params.distill:
model = get_distill_model(params)
else:
model = get_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
print(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(params.device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=params.device,
),
strict=True,
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
print(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(params.device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=params.device,
),
strict=True,
)
if params.iter > 0:
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
else:
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
print("Done!")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,586 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
# Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads checkpoints to generate waveforms.
This script is supposed to be used with the model trained by yourself.
If you want to use the pre-trained checkpoints provided by us, please refer to zipvoice_infer.py.
(1) Usage with a pre-trained checkpoint:
(a) ZipVoice model before distill:
python3 zipvoice/infer.py \
--checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
--distill 0 \
--token-file "data/tokens_emilia.txt" \
--test-list test.tsv \
--res-dir results/test \
--num-step 16 \
--guidance-scale 1
(b) ZipVoice-Distill:
python3 zipvoice/infer.py \
--checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
--distill 1 \
--token-file "data/tokens_emilia.txt" \
--test-list test.tsv \
--res-dir results/test_distill \
--num-step 8 \
--guidance-scale 3
(2) Usage with a directory of checkpoints (may require checkpoint averaging):
(a) ZipVoice model before distill:
python3 zipvoice/infer.py \
--exp-dir zipvoice/exp_zipvoice \
--epoch 11 \
--avg 4 \
--distill 0 \
--token-file "data/tokens_emilia.txt" \
--test-list test.tsv \
--res-dir results \
--num-step 16 \
--guidance-scale 1
(b) ZipVoice-Distill:
python3 zipvoice/infer.py \
--exp-dir zipvoice/exp_zipvoice_distill/ \
--iter 2000 \
--avg 0 \
--distill 1 \
--token-file "data/tokens_emilia.txt" \
--test-list test.tsv \
--res-dir results \
--num-step 8 \
--guidance-scale 3
"""
import argparse
import datetime as dt
import logging
import os
from pathlib import Path
from typing import Optional
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torchaudio
from checkpoint import load_checkpoint
from feature import TorchAudioFbank, TorchAudioFbankConfig
from lhotse.utils import fix_random_seed
from model import get_distill_model, get_model
from tokenizer import TokenizerEmilia, TokenizerLibriTTS
from train_flow import add_model_arguments, get_params
from vocos import Vocos
from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
default=None,
help="The checkpoint for inference. "
"If it is None, will use checkpoints under exp_dir",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipvoice/exp",
help="The experiment dir",
)
parser.add_argument(
"--epoch",
type=int,
default=0,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=4,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' or '--iter', avg=0 means no avg",
)
parser.add_argument(
"--vocoder-path",
type=str,
default=None,
help="The local vocos vocoder path, downloaded from huggingface, "
"will download the vocodoer from huggingface if it is None.",
)
parser.add_argument(
"--distill",
type=str2bool,
default=False,
help="Whether it is a distilled TTS model.",
)
parser.add_argument(
"--test-list",
type=str,
default=None,
help="The list of prompt speech, prompt_transcription, "
"and text to synthesize in the format of "
"'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
)
parser.add_argument(
"--res-dir",
type=str,
default="results",
help="Path name of the generated wavs dir",
)
parser.add_argument(
"--dataset",
type=str,
default="emilia",
choices=["emilia", "libritts"],
help="The used training dataset for the model to inference",
)
parser.add_argument(
"--guidance-scale",
type=float,
default=1.0,
help="The scale of classifier-free guidance during inference.",
)
parser.add_argument(
"--num-step",
type=int,
default=16,
help="The number of sampling steps.",
)
parser.add_argument(
"--feat-scale",
type=float,
default=0.1,
help="The scale factor of fbank feature",
)
parser.add_argument(
"--speed",
type=float,
default=1.0,
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
)
parser.add_argument(
"--t-shift",
type=float,
default=0.5,
help="Shift t to smaller ones if t_shift < 1.0",
)
parser.add_argument(
"--target-rms",
type=float,
default=0.1,
help="Target speech normalization rms value",
)
parser.add_argument(
"--seed",
type=int,
default=666,
help="Random seed",
)
add_model_arguments(parser)
return parser
def get_vocoder(vocos_local_path: Optional[str] = None):
if vocos_local_path:
vocos_local_path = "model/huggingface/vocos-mel-24khz/"
vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
state_dict = torch.load(
f"{vocos_local_path}/pytorch_model.bin",
weights_only=True,
map_location="cpu",
)
vocoder.load_state_dict(state_dict)
else:
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
return vocoder
def generate_sentence(
save_path: str,
prompt_text: str,
prompt_wav: str,
text: str,
model: nn.Module,
vocoder: nn.Module,
tokenizer: TokenizerEmilia,
feature_extractor: TorchAudioFbank,
device: torch.device,
num_step: int = 16,
guidance_scale: float = 1.0,
speed: float = 1.0,
t_shift: float = 0.5,
target_rms: float = 0.1,
feat_scale: float = 0.1,
sampling_rate: int = 24000,
):
"""
Generate waveform of a text based on a given prompt
waveform and its transcription.
Args:
save_path (str): Path to save the generated wav.
prompt_text (str): Transcription of the prompt wav.
prompt_wav (str): Path to the prompt wav file.
text (str): Text to be synthesized into a waveform.
model (nn.Module): The model used for generation.
vocoder (nn.Module): The vocoder used to convert features to waveforms.
tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
feature_extractor (TorchAudioFbank): The feature extractor used to
extract acoustic features.
device (torch.device): The device on which computations are performed.
num_step (int, optional): Number of steps for decoding. Defaults to 16.
guidance_scale (float, optional): Scale for classifier-free guidance.
Defaults to 1.0.
speed (float, optional): Speed control. Defaults to 1.0.
t_shift (float, optional): Time shift. Defaults to 0.5.
target_rms (float, optional): Target RMS for waveform normalization.
Defaults to 0.1.
feat_scale (float, optional): Scale for features.
Defaults to 0.1.
sampling_rate (int, optional): Sampling rate for the waveform.
Defaults to 24000.
Returns:
metrics (dict): Dictionary containing time and real-time
factor metrics for processing.
"""
# Convert text to tokens
tokens = tokenizer.texts_to_token_ids([text])
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
# Load and preprocess prompt wav
prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
if prompt_rms < target_rms:
prompt_wav = prompt_wav * target_rms / prompt_rms
if prompt_sampling_rate != sampling_rate:
resampler = torchaudio.transforms.Resample(
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
)
prompt_wav = resampler(prompt_wav)
# Extract features from prompt wav
prompt_features = feature_extractor.extract(
prompt_wav, sampling_rate=sampling_rate
).to(device)
prompt_features = prompt_features.unsqueeze(0) * feat_scale
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
# Start timing
start_t = dt.datetime.now()
# Generate features
(
pred_features,
pred_features_lens,
pred_prompt_features,
pred_prompt_features_lens,
) = model.sample(
tokens=tokens,
prompt_tokens=prompt_tokens,
prompt_features=prompt_features,
prompt_features_lens=prompt_features_lens,
speed=speed,
t_shift=t_shift,
duration="predict",
num_step=num_step,
guidance_scale=guidance_scale,
)
# Postprocess predicted features
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
# Start vocoder processing
start_vocoder_t = dt.datetime.now()
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
# Calculate processing times and real-time factors
t = (dt.datetime.now() - start_t).total_seconds()
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
wav_seconds = wav.shape[-1] / sampling_rate
rtf = t / wav_seconds
rtf_no_vocoder = t_no_vocoder / wav_seconds
rtf_vocoder = t_vocoder / wav_seconds
metrics = {
"t": t,
"t_no_vocoder": t_no_vocoder,
"t_vocoder": t_vocoder,
"wav_seconds": wav_seconds,
"rtf": rtf,
"rtf_no_vocoder": rtf_no_vocoder,
"rtf_vocoder": rtf_vocoder,
}
# Adjust wav volume if necessary
if prompt_rms < target_rms:
wav = wav * prompt_rms / target_rms
wav = wav[0].cpu().numpy()
sf.write(save_path, wav, sampling_rate)
return metrics
def generate(
params: AttributeDict,
model: nn.Module,
vocoder: nn.Module,
tokenizer: TokenizerEmilia,
):
total_t = []
total_t_no_vocoder = []
total_t_vocoder = []
total_wav_seconds = []
config = TorchAudioFbankConfig(
sampling_rate=params.sampling_rate,
n_mels=100,
n_fft=1024,
hop_length=256,
)
feature_extractor = TorchAudioFbank(config)
with open(params.test_list, "r") as fr:
lines = fr.readlines()
for i, line in enumerate(lines):
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
save_path = f"{params.wav_dir}/{wav_name}.wav"
metrics = generate_sentence(
save_path=save_path,
prompt_text=prompt_text,
prompt_wav=prompt_wav,
text=text,
model=model,
vocoder=vocoder,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
device=params.device,
num_step=params.num_step,
guidance_scale=params.guidance_scale,
speed=params.speed,
t_shift=params.t_shift,
target_rms=params.target_rms,
feat_scale=params.feat_scale,
sampling_rate=params.sampling_rate,
)
print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
total_t.append(metrics["t"])
total_t_no_vocoder.append(metrics["t_no_vocoder"])
total_t_vocoder.append(metrics["t_vocoder"])
total_wav_seconds.append(metrics["wav_seconds"])
print(f"Average RTF: " f"{np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
print(
f"Average RTF w/o vocoder: "
f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
)
print(
f"Average RTF vocoder: "
f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
)
@torch.inference_mode()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
if params.iter > 0:
params.suffix = (
f"wavs-iter-{params.iter}-avg"
f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
)
elif params.epoch > 0:
params.suffix = (
f"wavs-epoch-{params.epoch}-avg"
f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
)
else:
params.suffix = "wavs"
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
logging.info("Decoding started")
if torch.cuda.is_available():
params.device = torch.device("cuda", 0)
else:
params.device = torch.device("cpu")
logging.info(f"Device: {params.device}")
if params.dataset == "emilia":
tokenizer = TokenizerEmilia(
token_file=params.token_file, token_type=params.token_type
)
elif params.dataset == "libritts":
tokenizer = TokenizerLibriTTS(
token_file=params.token_file, token_type=params.token_type
)
params.vocab_size = tokenizer.vocab_size
params.pad_id = tokenizer.pad_id
logging.info(params)
fix_random_seed(params.seed)
logging.info("About to create model")
if params.distill:
model = get_distill_model(params)
else:
model = get_model(params)
if params.checkpoint:
load_checkpoint(params.checkpoint, model, strict=True)
else:
if params.avg == 0:
if params.iter > 0:
load_checkpoint(
f"{params.exp_dir}/checkpoint-{params.iter}.pt", model, strict=True
)
else:
load_checkpoint(
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True
)
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(params.device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=params.device,
),
strict=True,
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(params.device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=params.device,
),
strict=True,
)
model = model.to(params.device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
vocoder = get_vocoder(params.vocoder_path)
vocoder = vocoder.to(params.device)
vocoder.eval()
num_param = sum([p.numel() for p in vocoder.parameters()])
logging.info(f"Number of vocoder parameters: {num_param}")
params.wav_dir = f"{params.res_dir}/{params.suffix}"
os.makedirs(params.wav_dir, exist_ok=True)
assert (
params.test_list is not None
), "Please provide --test-list for speech synthesize."
generate(
params=params,
model=model,
vocoder=vocoder,
tokenizer=tokenizer,
)
logging.info("Done!")
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()

View File

@@ -0,0 +1,578 @@
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
# Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
import torch.nn as nn
from scaling import ScheduledFloat
from solver import EulerSolver
from torch.nn.parallel import DistributedDataParallel as DDP
from utils import (
AttributeDict,
condition_time_mask,
get_tokens_index,
make_pad_mask,
pad_labels,
prepare_avg_tokens_durations,
to_int_tuple,
)
from zipformer import TTSZipformer
def get_model(params: AttributeDict) -> nn.Module:
"""Get the normal TTS model."""
fm_decoder = get_fm_decoder_model(params)
text_encoder = get_text_encoder_model(params)
model = TtsModel(
fm_decoder=fm_decoder,
text_encoder=text_encoder,
text_embed_dim=params.text_embed_dim,
feat_dim=params.feat_dim,
vocab_size=params.vocab_size,
pad_id=params.pad_id,
)
return model
def get_distill_model(params: AttributeDict) -> nn.Module:
"""Get the distillation TTS model."""
fm_decoder = get_fm_decoder_model(params, distill=True)
text_encoder = get_text_encoder_model(params)
model = DistillTTSModelTrainWrapper(
fm_decoder=fm_decoder,
text_encoder=text_encoder,
text_embed_dim=params.text_embed_dim,
feat_dim=params.feat_dim,
vocab_size=params.vocab_size,
pad_id=params.pad_id,
)
return model
def get_fm_decoder_model(params: AttributeDict, distill: bool = False) -> nn.Module:
"""Get the Zipformer-based FM decoder model."""
encoder = TTSZipformer(
in_dim=params.feat_dim * 3,
out_dim=params.feat_dim,
downsampling_factor=to_int_tuple(params.fm_decoder_downsampling_factor),
num_encoder_layers=to_int_tuple(params.fm_decoder_num_layers),
cnn_module_kernel=to_int_tuple(params.fm_decoder_cnn_module_kernel),
encoder_dim=params.fm_decoder_dim,
feedforward_dim=params.fm_decoder_feedforward_dim,
num_heads=params.fm_decoder_num_heads,
query_head_dim=params.query_head_dim,
pos_head_dim=params.pos_head_dim,
value_head_dim=params.value_head_dim,
pos_dim=params.pos_dim,
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
use_time_embed=True,
time_embed_dim=192,
use_guidance_scale_embed=distill,
)
return encoder
def get_text_encoder_model(params: AttributeDict) -> nn.Module:
"""Get the Zipformer-based text encoder model."""
encoder = TTSZipformer(
in_dim=params.text_embed_dim,
out_dim=params.feat_dim,
downsampling_factor=to_int_tuple(params.text_encoder_downsampling_factor),
num_encoder_layers=to_int_tuple(params.text_encoder_num_layers),
cnn_module_kernel=to_int_tuple(params.text_encoder_cnn_module_kernel),
encoder_dim=params.text_encoder_dim,
feedforward_dim=params.text_encoder_feedforward_dim,
num_heads=params.text_encoder_num_heads,
query_head_dim=params.query_head_dim,
pos_head_dim=params.pos_head_dim,
value_head_dim=params.value_head_dim,
pos_dim=params.pos_dim,
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
use_time_embed=False,
)
return encoder
class TtsModel(nn.Module):
"""The normal TTS model."""
def __init__(
self,
fm_decoder: nn.Module,
text_encoder: nn.Module,
text_embed_dim: int,
feat_dim: int,
vocab_size: int,
pad_id: int = 0,
):
"""
Args:
            fm_decoder: the flow-matching decoder model; its inputs are the
                condition embeddings and noisy acoustic features, and its
                outputs are the predicted velocities used to refine the
                noisy acoustic features.
            text_encoder: the text encoder model; its inputs are text
                embeddings, its outputs are contextualized text embeddings.
text_embed_dim: dimension of text embedding.
feat_dim: dimension of acoustic features.
vocab_size: vocabulary size.
pad_id: padding id.
"""
super().__init__()
self.feat_dim = feat_dim
self.text_embed_dim = text_embed_dim
self.pad_id = pad_id
self.fm_decoder = fm_decoder
self.text_encoder = text_encoder
self.embed = nn.Embedding(vocab_size, text_embed_dim)
self.distill = False
def forward_fm_decoder(
self,
t: torch.Tensor,
xt: torch.Tensor,
text_condition: torch.Tensor,
speech_condition: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
guidance_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Compute velocity.
Args:
t: A tensor of shape (N, 1, 1) or a tensor of a float,
in the range of (0, 1).
xt: the input of the current timestep, including condition
embeddings and noisy acoustic features.
text_condition: the text condition embeddings, with the
shape (batch, seq_len, emb_dim).
speech_condition: the speech condition embeddings, with the
shape (batch, seq_len, emb_dim).
padding_mask: The mask for padding, True means masked
position, with the shape (N, T).
guidance_scale: The guidance scale in classifier-free guidance,
which is a tensor of shape (N, 1, 1) or a tensor of a float.
Returns:
predicted velocity, with the shape (batch, seq_len, emb_dim).
"""
assert t.dim() in (0, 3)
# Handle t with the shape (N, 1, 1):
        # squeeze the last dimension if its size is 1.
while t.dim() > 1 and t.size(-1) == 1:
t = t.squeeze(-1)
if guidance_scale is not None:
while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1:
guidance_scale = guidance_scale.squeeze(-1)
        # Handle t with a single value: expand it to the batch size.
if t.dim() == 0:
t = t.repeat(xt.shape[0])
if guidance_scale is not None and guidance_scale.dim() == 0:
guidance_scale = guidance_scale.repeat(xt.shape[0])
xt = torch.cat([xt, text_condition, speech_condition], dim=2)
vt = self.fm_decoder(
x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale
)
return vt
def forward_text_embed(
self,
tokens: List[List[int]],
):
"""
Get the text embeddings.
Args:
tokens: a list of list of token ids.
Returns:
embed: the text embeddings, shape (batch, seq_len, emb_dim).
tokens_lens: the length of each token sequence, shape (batch,).
"""
device = (
self.device if isinstance(self, DDP) else next(self.parameters()).device
)
tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
embed = self.embed(tokens_padded) # (B, S, C)
tokens_lens = torch.tensor(
[len(token) for token in tokens], dtype=torch.int64, device=device
)
tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
embed = self.text_encoder(
x=embed, t=None, padding_mask=tokens_padding_mask
) # (B, S, C)
return embed, tokens_lens
def forward_text_condition(
self,
embed: torch.Tensor,
tokens_lens: torch.Tensor,
features_lens: torch.Tensor,
):
"""
Get the text condition with the same length of the acoustic feature.
Args:
embed: the text embeddings, shape (batch, token_seq_len, emb_dim).
tokens_lens: the length of each token sequence, shape (batch,).
features_lens: the length of each acoustic feature sequence,
shape (batch,).
Returns:
text_condition: the text condition, shape
(batch, feature_seq_len, emb_dim).
padding_mask: the padding mask of text condition, shape
(batch, feature_seq_len).
"""
num_frames = int(features_lens.max())
padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens)
tokens_index = get_tokens_index(tokens_durations, num_frames).to(
embed.device
) # (B, T)
text_condition = torch.gather(
embed,
dim=1,
index=tokens_index.unsqueeze(-1).expand(
embed.size(0), num_frames, embed.size(-1)
),
) # (B, T, F)
return text_condition, padding_mask
def forward_text_train(
self,
tokens: List[List[int]],
features_lens: torch.Tensor,
):
"""
Process text for training, given text tokens and real feature lengths.
"""
embed, tokens_lens = self.forward_text_embed(tokens)
text_condition, padding_mask = self.forward_text_condition(
embed, tokens_lens, features_lens
)
return (
text_condition,
padding_mask,
)
def forward_text_inference_gt_duration(
self,
tokens: List[List[int]],
features_lens: torch.Tensor,
prompt_tokens: List[List[int]],
prompt_features_lens: torch.Tensor,
):
"""
Process text for inference, given text tokens, real feature lengths and prompts.
"""
tokens = [
prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
]
features_lens = prompt_features_lens + features_lens
embed, tokens_lens = self.forward_text_embed(tokens)
text_condition, padding_mask = self.forward_text_condition(
embed, tokens_lens, features_lens
)
return text_condition, padding_mask
def forward_text_inference_ratio_duration(
self,
tokens: List[List[int]],
prompt_tokens: List[List[int]],
prompt_features_lens: torch.Tensor,
speed: float,
):
"""
Process text for inference, given text tokens and prompts,
feature lengths are predicted with the ratio of token numbers.
"""
device = (
self.device if isinstance(self, DDP) else next(self.parameters()).device
)
cat_tokens = [
prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
]
prompt_tokens_lens = torch.tensor(
[len(token) for token in prompt_tokens], dtype=torch.int64, device=device
)
cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens)
features_lens = torch.ceil(
(prompt_features_lens / prompt_tokens_lens * cat_tokens_lens / speed)
).to(dtype=torch.int64)
text_condition, padding_mask = self.forward_text_condition(
cat_embed, cat_tokens_lens, features_lens
)
return text_condition, padding_mask
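    # Worked example of the ratio-based duration prediction above (illustration
    # only): if the prompt has 250 feature frames over 25 prompt tokens
    # (10 frames per token) and the concatenated prompt+target sequence has
    # 40 tokens, the predicted total length at speed 1.0 is
    # ceil(250 / 25 * 40) = 400 frames; speed 2.0 halves it to 200 frames.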
def forward(
self,
tokens: List[List[int]],
features: torch.Tensor,
features_lens: torch.Tensor,
noise: torch.Tensor,
t: torch.Tensor,
condition_drop_ratio: float = 0.0,
) -> torch.Tensor:
"""Forward pass of the model for training.
Args:
tokens: a list of list of token ids.
features: the acoustic features, with the shape (batch, seq_len, feat_dim).
features_lens: the length of each acoustic feature sequence, shape (batch,).
            noise: the initial noise, with the shape (batch, seq_len, feat_dim).
t: the time step, with the shape (batch, 1, 1).
condition_drop_ratio: the ratio of dropped text condition.
Returns:
fm_loss: the flow-matching loss.
"""
(text_condition, padding_mask,) = self.forward_text_train(
tokens=tokens,
features_lens=features_lens,
)
speech_condition_mask = condition_time_mask(
features_lens=features_lens,
mask_percent=(0.7, 1.0),
max_len=features.size(1),
)
speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
if condition_drop_ratio > 0.0:
drop_mask = (
torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
> condition_drop_ratio
)
text_condition = text_condition * drop_mask
xt = features * t + noise * (1 - t)
ut = features - noise # (B, T, F)
vt = self.forward_fm_decoder(
t=t,
xt=xt,
text_condition=text_condition,
speech_condition=speech_condition,
padding_mask=padding_mask,
)
loss_mask = speech_condition_mask & (~padding_mask)
fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
return fm_loss
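    # Flow-matching recap for forward() above (comment only): with the linear
    # path x_t = t * x_1 + (1 - t) * x_0, where x_1 are the clean features and
    # x_0 is the noise, the target velocity is d x_t / d t = x_1 - x_0, which
    # is exactly `ut`; the loss is the masked MSE between the predicted
    # velocity `vt` and this target.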
def sample(
self,
tokens: List[List[int]],
prompt_tokens: List[List[int]],
prompt_features: torch.Tensor,
prompt_features_lens: torch.Tensor,
features_lens: Optional[torch.Tensor] = None,
speed: float = 1.0,
t_shift: float = 1.0,
duration: str = "predict",
num_step: int = 5,
guidance_scale: float = 0.5,
) -> torch.Tensor:
"""
Generate acoustic features, given text tokens, prompts feature
and prompt transcription's text tokens.
Args:
tokens: a list of list of text tokens.
prompt_tokens: a list of list of prompt tokens.
prompt_features: the prompt feature with the shape
(batch_size, seq_len, feat_dim).
prompt_features_lens: the length of each prompt feature,
with the shape (batch_size,).
            features_lens: the length of the predicted feature, with the
shape (batch_size,). It is used only when duration is "real".
duration: "real" or "predict". If "real", the predicted
feature length is given by features_lens.
num_step: the number of steps to use in the ODE solver.
guidance_scale: the guidance scale for classifier-free guidance.
"""
assert duration in ["real", "predict"]
if duration == "predict":
(
text_condition,
padding_mask,
) = self.forward_text_inference_ratio_duration(
tokens=tokens,
prompt_tokens=prompt_tokens,
prompt_features_lens=prompt_features_lens,
speed=speed,
)
else:
assert features_lens is not None
text_condition, padding_mask = self.forward_text_inference_gt_duration(
tokens=tokens,
features_lens=features_lens,
prompt_tokens=prompt_tokens,
prompt_features_lens=prompt_features_lens,
)
batch_size, num_frames, _ = text_condition.shape
speech_condition = torch.nn.functional.pad(
prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
) # (B, T, F)
# False means speech condition positions.
speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
speech_condition = torch.where(
speech_condition_mask.unsqueeze(-1),
torch.zeros_like(speech_condition),
speech_condition,
)
x0 = torch.randn(
batch_size, num_frames, self.feat_dim, device=text_condition.device
)
solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
x1 = solver.sample(
x=x0,
text_condition=text_condition,
speech_condition=speech_condition,
padding_mask=padding_mask,
num_step=num_step,
guidance_scale=guidance_scale,
t_shift=t_shift,
)
x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens
x1_prompt = torch.zeros(
x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device
)
x1_wo_prompt = torch.zeros(
x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device
)
for i in range(x1.size(0)):
x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[
i,
prompt_features_lens[i] : prompt_features_lens[i]
+ x1_wo_prompt_lens[i],
]
x1_prompt[i, : prompt_features_lens[i], :] = x1[
i, : prompt_features_lens[i]
]
return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens
def sample_intermediate(
self,
tokens: List[List[int]],
features: torch.Tensor,
features_lens: torch.Tensor,
noise: torch.Tensor,
speech_condition_mask: torch.Tensor,
t_start: torch.Tensor,
t_end: torch.Tensor,
num_step: int = 1,
guidance_scale: torch.Tensor = None,
) -> torch.Tensor:
"""
Generate acoustic features in intermediate timesteps.
Args:
tokens: List of list of token ids.
features: The acoustic features, with the shape (batch, seq_len, feat_dim).
features_lens: The length of each acoustic feature sequence,
with the shape (batch,).
noise: The initial noise, with the shape (batch, seq_len, feat_dim).
speech_condition_mask: The mask for speech condition, True means
non-condition positions, with the shape (batch, seq_len).
t_start: The start timestep, with the shape (batch, 1, 1).
t_end: The end timestep, with the shape (batch, 1, 1).
num_step: The number of steps for sampling.
guidance_scale: The scale for classifier-free guidance inference,
with the shape (batch, 1, 1).
"""
(text_condition, padding_mask,) = self.forward_text_train(
tokens=tokens,
features_lens=features_lens,
)
speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
x_t_end = solver.sample(
x=noise,
text_condition=text_condition,
speech_condition=speech_condition,
padding_mask=padding_mask,
num_step=num_step,
guidance_scale=guidance_scale,
t_start=t_start,
t_end=t_end,
)
x_t_end_lens = (~padding_mask).sum(-1)
return x_t_end, x_t_end_lens
class DistillTTSModelTrainWrapper(TtsModel):
"""Wrapper for training the distilled TTS model."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.distill = True
def forward(
self,
tokens: List[List[int]],
features: torch.Tensor,
features_lens: torch.Tensor,
noise: torch.Tensor,
speech_condition_mask: torch.Tensor,
t_start: torch.Tensor,
t_end: torch.Tensor,
num_step: int = 1,
guidance_scale: torch.Tensor = None,
) -> torch.Tensor:
return self.sample_intermediate(
tokens=tokens,
features=features,
features_lens=features_lens,
noise=noise,
speech_condition_mask=speech_condition_mask,
t_start=t_start,
t_end=t_end,
num_step=num_step,
guidance_scale=guidance_scale,
)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,277 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
class DiffusionModel(torch.nn.Module):
"""A wrapper of diffusion models for inference.
Args:
model: The diffusion model.
distill: Whether it is a distillation model.
"""
def __init__(
self,
model: torch.nn.Module,
distill: bool = False,
func_name: str = "forward_fm_decoder",
):
super().__init__()
self.model = model
self.distill = distill
self.func_name = func_name
self.model_func = getattr(self.model, func_name)
def forward(
self,
t: torch.Tensor,
x: torch.Tensor,
text_condition: torch.Tensor,
speech_condition: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
guidance_scale: Union[float, torch.Tensor] = 0.0,
**kwargs
) -> torch.Tensor:
"""
        Forward function that handles classifier-free guidance.
Args:
t: The current timestep, a tensor of shape (batch, 1, 1) or a tensor of a single float.
x: The initial value, with the shape (batch, seq_len, emb_dim).
            text_condition: The text_condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
            speech_condition: The speech_condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
guidance_scale: The scale of classifier-free guidance, a float or a tensor of shape (batch, 1, 1).
        Returns:
The prediction with the shape (batch, seq_len, emb_dim).
"""
if not torch.is_tensor(guidance_scale):
guidance_scale = torch.tensor(
guidance_scale, dtype=t.dtype, device=t.device
)
if self.distill:
return self.model_func(
t=t,
xt=x,
text_condition=text_condition,
speech_condition=speech_condition,
padding_mask=padding_mask,
guidance_scale=guidance_scale,
**kwargs
)
if (guidance_scale == 0.0).all():
return self.model_func(
t=t,
xt=x,
text_condition=text_condition,
speech_condition=speech_condition,
padding_mask=padding_mask,
**kwargs
)
else:
if t.dim() != 0:
t = torch.cat([t] * 2, dim=0)
x = torch.cat([x] * 2, dim=0)
padding_mask = torch.cat([padding_mask] * 2, dim=0)
text_condition = torch.cat(
[torch.zeros_like(text_condition), text_condition], dim=0
)
if t.dim() == 0:
if t > 0.5:
speech_condition = torch.cat(
[torch.zeros_like(speech_condition), speech_condition], dim=0
)
else:
guidance_scale = guidance_scale * 2
speech_condition = torch.cat(
[speech_condition, speech_condition], dim=0
)
else:
assert t.dim() > 0, t
larger_t_index = (t > 0.5).squeeze(1).squeeze(1)
zero_speech_condition = torch.cat(
[torch.zeros_like(speech_condition), speech_condition], dim=0
)
speech_condition = torch.cat(
[speech_condition, speech_condition], dim=0
)
speech_condition[larger_t_index] = zero_speech_condition[larger_t_index]
guidance_scale[~larger_t_index[: larger_t_index.size(0) // 2]] *= 2
data_uncond, data_cond = self.model_func(
t=t,
xt=x,
text_condition=text_condition,
speech_condition=speech_condition,
padding_mask=padding_mask,
**kwargs
).chunk(2, dim=0)
res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
return res
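# Classifier-free guidance recap (comment only): the combination
#     res = (1 + g) * cond - g * uncond
# reduces to the plain conditional prediction for g = 0 and extrapolates away
# from the unconditional prediction for g > 0, e.g. g = 1 gives
# cond + (cond - uncond).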
class EulerSolver:
def __init__(
self,
model: torch.nn.Module,
distill: bool = False,
func_name: str = "forward_fm_decoder",
):
"""Construct a Euler Solver
Args:
model: The diffusion model.
distill: Whether it is distillation model.
"""
self.model = DiffusionModel(model, distill=distill, func_name=func_name)
def sample(
self,
x: torch.Tensor,
text_condition: torch.Tensor,
speech_condition: torch.Tensor,
padding_mask: torch.Tensor,
num_step: int = 10,
guidance_scale: Union[float, torch.Tensor] = 0.0,
t_start: Union[float, torch.Tensor] = 0.0,
t_end: Union[float, torch.Tensor] = 1.0,
t_shift: float = 1.0,
**kwargs
) -> torch.Tensor:
"""
Compute the sample at time `t_end` by Euler Solver.
Args:
x: The initial value at time `t_start`, with the shape (batch, seq_len, emb_dim).
            text_condition: The text condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
            speech_condition: The speech condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
num_step: The number of ODE steps.
guidance_scale: The scale for classifier-free guidance, which is
a float or a tensor with the shape (batch, 1, 1).
t_start: the start timestep in the range of [0, 1],
which is a float or a tensor with the shape (batch, 1, 1).
t_end: the end time_step in the range of [0, 1],
which is a float or a tensor with the shape (batch, 1, 1).
t_shift: shift the t toward smaller numbers so that the sampling
will emphasize low SNR region. Should be in the range of (0, 1].
The shifting will be more significant when the number is smaller.
Returns:
The approximated solution at time `t_end`.
"""
device = x.device
if torch.is_tensor(t_start) and t_start.dim() > 0:
timesteps = get_time_steps_batch(
t_start=t_start,
t_end=t_end,
num_step=num_step,
t_shift=t_shift,
device=device,
)
else:
timesteps = get_time_steps(
t_start=t_start,
t_end=t_end,
num_step=num_step,
t_shift=t_shift,
device=device,
)
for step in range(num_step):
v = self.model(
t=timesteps[step],
x=x,
text_condition=text_condition,
speech_condition=speech_condition,
padding_mask=padding_mask,
guidance_scale=guidance_scale,
**kwargs
)
x = x + v * (timesteps[step + 1] - timesteps[step])
return x
def get_time_steps(
t_start: float = 0.0,
t_end: float = 1.0,
num_step: int = 10,
t_shift: float = 1.0,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Compute the intermediate time steps for sampling.
Args:
t_start: The starting time of the sampling (default is 0).
        t_end: The ending time of the sampling (default is 1).
        num_step: The number of sampling steps.
t_shift: shift the t toward smaller numbers so that the sampling
will emphasize low SNR region. Should be in the range of (0, 1].
The shifting will be more significant when the number is smaller.
device: A torch device.
Returns:
The time step with the shape (num_step + 1,).
"""
timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
return timesteps
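# Worked example (illustration only): with num_step=4, t_start=0, t_end=1 and
# t_shift=0.5, the uniform grid [0, 0.25, 0.5, 0.75, 1.0] is warped by
# t' = 0.5 * t / (1 - 0.5 * t) to approximately [0, 0.143, 0.333, 0.6, 1.0],
# i.e. the steps are pushed toward t = 0, the low-SNR region.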
def get_time_steps_batch(
t_start: torch.Tensor,
t_end: torch.Tensor,
num_step: int = 10,
t_shift: float = 1.0,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Compute the intermediate time steps for sampling in the batch mode.
Args:
t_start: The starting time of the sampling (default is 0), with the shape (batch, 1, 1).
        t_end: The ending time of the sampling (default is 1), with the shape (batch, 1, 1).
        num_step: The number of sampling steps.
t_shift: shift the t toward smaller numbers so that the sampling
will emphasize low SNR region. Should be in the range of (0, 1].
The shifting will be more significant when the number is smaller.
device: A torch device.
Returns:
The time step with the shape (num_step + 1, N, 1, 1).
"""
while t_start.dim() > 1 and t_start.size(-1) == 1:
t_start = t_start.squeeze(-1)
while t_end.dim() > 1 and t_end.size(-1) == 1:
t_end = t_end.squeeze(-1)
assert t_start.dim() == t_end.dim() == 1
timesteps_shape = (num_step + 1, t_start.size(0))
timesteps = torch.zeros(timesteps_shape, device=device)
for i in range(t_start.size(0)):
timesteps[:, i] = torch.linspace(t_start[i], t_end[i], steps=num_step + 1)
timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
return timesteps.unsqueeze(-1).unsqueeze(-1)

View File

@@ -0,0 +1,570 @@
# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao
# Han Zhu,
# Wei Kang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import unicodedata
from functools import reduce
from typing import Dict, List, Optional
import cn2an
import inflect
import jieba
from pypinyin import Style, lazy_pinyin
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
try:
from piper_phonemize import phonemize_espeak
except Exception as ex:
raise RuntimeError(
f"{ex}\nPlease run\n"
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
)
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
("etc", "et cetera"),
("btw", "by the way"),
]
]
def intersperse(sequence, item=0):
result = [item] * (len(sequence) * 2 + 1)
result[1::2] = sequence
return result
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def _remove_commas(m):
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
def _expand_percent(m):
return m.group(1).replace("%", " percent ")
def _expand_dollars(m):
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return " " + match + " dollars " # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return " %s %s " % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return " %s %s " % (cents, cent_unit)
else:
return " zero dollars "
def fraction_to_words(numerator, denominator):
if numerator == 1 and denominator == 2:
return " one half "
if numerator == 1 and denominator == 4:
return " one quarter "
if denominator == 2:
return " " + _inflect.number_to_words(numerator) + " halves "
if denominator == 4:
return " " + _inflect.number_to_words(numerator) + " quarters "
return (
" "
+ _inflect.number_to_words(numerator)
+ " "
+ _inflect.ordinal(_inflect.number_to_words(denominator))
+ " "
)
def _expand_fraction(m):
numerator = int(m.group(1))
denominator = int(m.group(2))
return fraction_to_words(numerator, denominator)
def _expand_ordinal(m):
return " " + _inflect.number_to_words(m.group(0)) + " "
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return " two thousand "
elif num > 2000 and num < 2010:
return " two thousand " + _inflect.number_to_words(num % 100) + " "
elif num % 100 == 0:
return " " + _inflect.number_to_words(num // 100) + " hundred "
else:
return (
" "
+ _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
", ", " "
)
+ " "
)
else:
return " " + _inflect.number_to_words(num, andword="") + " "
# Normalize numbers pronunciation
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_fraction_re, _expand_fraction, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_percent_number_re, _expand_percent, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
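# Illustrative sketch (not part of the recipe): rough behaviour of the English
# number normalization above; exact spacing differs because each rule pads its
# replacement with spaces.
def _demo_normalize_numbers():
    for s in ["$2.50", "1/2 cup", "15% off", "the 3rd time", "1,000 items"]:
        # e.g. "$2.50" -> roughly "two dollars, fifty cents",
        # "15% off" -> "fifteen percent off", "the 3rd time" -> "the third time".
        print(s, "->", normalize_numbers(s))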
# Convert numbers to Chinese pronunciation
def number_to_chinese(text):
text = cn2an.transform(text, "an2cn")
return text
def map_punctuations(text):
text = text.replace("", ",")
text = text.replace("", ".")
text = text.replace("", "!")
text = text.replace("", "?")
text = text.replace("", ";")
text = text.replace("", ":")
text = text.replace("", ",")
text = text.replace("", "'")
text = text.replace("", '"')
text = text.replace("", '"')
text = text.replace("", "'")
text = text.replace("", "")
text = text.replace("···", "")
text = text.replace("・・・", "")
text = text.replace("...", "")
return text
def is_chinese(char):
if char >= "\u4e00" and char <= "\u9fa5":
return True
else:
return False
def is_alphabet(char):
if (char >= "\u0041" and char <= "\u005a") or (
char >= "\u0061" and char <= "\u007a"
):
return True
else:
return False
def is_hangul(char):
letters = unicodedata.normalize("NFD", char)
return all(
["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters]
)
def is_japanese(char):
return any(
[
start <= char <= end
for start, end in [
("\u3041", "\u3096"),
("\u30a0", "\u30ff"),
("\uff5f", "\uff9f"),
("\u31f0", "\u31ff"),
("\u3220", "\u3243"),
("\u3280", "\u337f"),
]
]
)
def get_segment(text: str) -> List[str]:
# sentence --> [ch_part, en_part, ch_part, ...]
# example :
# input : 我们是小米人,是吗? Yes I think so!霍...啦啦啦
# output : [('我们是小米人,是吗? ', 'zh'), ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')]
segments = []
types = []
flag = 0
temp_seg = ""
temp_lang = ""
for i, ch in enumerate(text):
if is_chinese(ch):
types.append("zh")
elif is_alphabet(ch):
types.append("en")
else:
types.append("other")
assert len(types) == len(text)
for i in range(len(types)):
# find the first char of the seg
if flag == 0:
temp_seg += text[i]
temp_lang = types[i]
flag = 1
else:
if temp_lang == "other":
if types[i] == temp_lang:
temp_seg += text[i]
else:
temp_seg += text[i]
temp_lang = types[i]
else:
if types[i] == temp_lang:
temp_seg += text[i]
elif types[i] == "other":
temp_seg += text[i]
else:
segments.append((temp_seg, temp_lang))
temp_seg = text[i]
temp_lang = types[i]
flag = 1
segments.append((temp_seg, temp_lang))
return segments
def preprocess(text: str) -> str:
text = map_punctuations(text)
return text
def tokenize_ZH(text: str) -> List[str]:
try:
text = number_to_chinese(text)
segs = list(jieba.cut(text))
full = lazy_pinyin(
segs, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True
)
phones = []
for x in full:
# valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")):
phones.append(x)
continue
initial = to_initials(x, strict=False)
# don't want to share tokens with espeak tokens, so use tone3 style
final = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
if initial != "":
# don't want to share tokens with espeak tokens, so add a '0' after each initial
phones.append(initial + "0")
if final != "":
phones.append(final)
return phones
except Exception:
return []
def tokenize_EN(text: str) -> List[str]:
try:
text = expand_abbreviations(text)
text = normalize_numbers(text)
tokens = phonemize_espeak(text, "en-us")
tokens = reduce(lambda x, y: x + y, tokens)
return tokens
except Exception:
return []
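# Illustrative sketch (not part of the recipe): tokenize_ZH() emits pinyin
# initials suffixed with '0' plus tone3 finals, while tokenize_EN() emits
# espeak phonemes, so the Chinese and English token inventories never collide.
def _demo_tokenize():
    print(tokenize_ZH("小米"))   # e.g. ['x0', 'iao3', 'm0', 'i3']
    print(tokenize_EN("hello"))  # espeak phonemes for "hello"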
class TokenizerEmilia(object):
def __init__(self, token_file: Optional[str] = None, token_type="phone"):
"""
Args:
token_file: the file that maps tokens to ids,
which is a text file with '{token}\t{token_id}' per line.
"""
assert (
token_type == "phone"
), f"Only support phone tokenizer for Emilia, but get {token_type}."
self.has_tokens = False
if token_file is None:
logging.debug(
"Initialize Tokenizer without tokens file, will fail when map to ids."
)
return
self.token2id: Dict[str, int] = {}
with open(token_file, "r", encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split("\t")
token, id = info[0], int(info[1])
assert token not in self.token2id, token
self.token2id[token] = id
self.pad_id = self.token2id["_"] # padding
self.vocab_size = len(self.token2id)
self.has_tokens = True
def texts_to_token_ids(
self,
texts: List[str],
) -> List[List[int]]:
return self.tokens_to_token_ids(self.texts_to_tokens(texts))
def texts_to_tokens(
self,
texts: List[str],
) -> List[List[str]]:
"""
Args:
texts:
A list of transcripts.
Returns:
Return a list of a list of tokens [utterance][token]
"""
for i in range(len(texts)):
# Text normalization
texts[i] = preprocess(texts[i])
phoneme_list = []
for text in texts:
# now only en and ch
segments = get_segment(text)
all_phoneme = []
for index in range(len(segments)):
seg = segments[index]
if seg[1] == "zh":
phoneme = tokenize_ZH(seg[0])
else:
if seg[1] != "en":
logging.warning(
f"The lang should be en, given {seg[1]}, skipping segment : {seg}"
)
continue
phoneme = tokenize_EN(seg[0])
all_phoneme += phoneme
phoneme_list.append(all_phoneme)
return phoneme_list
def tokens_to_token_ids(
self,
tokens: List[List[str]],
intersperse_blank: bool = False,
) -> List[List[int]]:
"""
Args:
tokens_list:
A list of token list, each corresponding to one utterance.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
Returns:
Return a list of token id list [utterance][token_id]
"""
assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
token_ids = []
for tks in tokens:
ids = []
for t in tks:
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
ids.append(self.token2id[t])
if intersperse_blank:
ids = intersperse(ids, self.pad_id)
token_ids.append(ids)
return token_ids
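# Illustrative sketch (not part of the recipe): typical phone-tokenizer usage,
# assuming a token file such as data/tokens_emilia.txt from the data
# preparation stage ('{token}\t{token_id}' per line).
def _demo_tokenizer_emilia():
    tokenizer = TokenizerEmilia(token_file="data/tokens_emilia.txt")
    texts = ["我们是小米人。 Yes I think so!"]
    tokens = tokenizer.texts_to_tokens(texts)
    token_ids = tokenizer.texts_to_token_ids(texts)
    print(tokens[0][:10], token_ids[0][:10])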
class TokenizerLibriTTS(object):
def __init__(self, token_file: str, token_type: str):
"""
Args:
token_type: the type of tokenizer, e.g., bpe, char, phone.
token_file: the file that maps tokens to ids,
which is a text file with '{token}\t{token_id}' per line if the type is
char or phone, otherwise it is a bpe model file.
"""
self.type = token_type
assert token_type in ["bpe", "char", "phone"]
# Parse token file
if token_type == "bpe":
import sentencepiece as spm
self.sp = spm.SentencePieceProcessor()
self.sp.load(token_file)
self.pad_id = self.sp.piece_to_id("<pad>")
self.vocab_size = self.sp.get_piece_size()
else:
self.token2id: Dict[str, int] = {}
with open(token_file, "r", encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split("\t")
token, id = info[0], int(info[1])
assert token not in self.token2id, token
self.token2id[token] = id
self.pad_id = self.token2id["_"] # padding
self.vocab_size = len(self.token2id)
try:
from tacotron_cleaner.cleaners import custom_english_cleaners as cleaner
except Exception as ex:
raise RuntimeError(
f"{ex}\nPlease run\n"
"pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/"
)
self.cleaner = cleaner
def texts_to_token_ids(
self,
texts: List[str],
lang: str = "en-us",
) -> List[List[int]]:
"""
Args:
texts:
A list of transcripts.
lang:
Language argument passed to phonemize_espeak().
Returns:
Return a list of token id list [utterance][token_id]
"""
for i in range(len(texts)):
# Text normalization
texts[i] = self.cleaner(texts[i])
if self.type == "bpe":
token_ids_list = self.sp.encode(texts)
elif self.type == "phone":
token_ids_list = []
for text in texts:
tokens_list = phonemize_espeak(text.lower(), lang)
tokens = []
for t in tokens_list:
tokens.extend(t)
token_ids = []
for t in tokens:
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
token_ids.append(self.token2id[t])
token_ids_list.append(token_ids)
else:
token_ids_list = []
for text in texts:
token_ids = []
for t in text:
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
token_ids.append(self.token2id[t])
token_ids_list.append(token_ids)
return token_ids_list
def tokens_to_token_ids(
self,
tokens_list: List[str],
) -> List[List[int]]:
"""
Args:
tokens_list:
A list of token list, each corresponding to one utterance.
Returns:
Return a list of token id list [utterance][token_id]
"""
token_ids_list = []
for tokens in tokens_list:
token_ids = []
for t in tokens:
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
token_ids.append(self.token2id[t])
token_ids_list.append(token_ids)
return token_ids_list
if __name__ == "__main__":
text = "我们是5年小米人,是吗? Yes I think so! mr king, 5 years, from 2019 to 2024. 霍...啦啦啦超过90%的人咯...?!9204"
tokenizer = TokenizerEmilia()
tokens = tokenizer.texts_to_tokens([text])
print(f"tokens : {tokens}")
tokens2 = "|".join(tokens[0])
print(f"tokens2 : {tokens2}")
tokens2 = tokens2.split("|")
assert tokens[0] == tokens2

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,456 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao,
# Zengrui Jin,
# Han Zhu,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import torch
from feature import TorchAudioFbank, TorchAudioFbankConfig
from lhotse import CutSet, load_manifest_lazy, validate
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
)
from lhotse.dataset.collation import collate_audio
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
BatchIO,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed, ifnone
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
SAMPLING_RATE = 24000
class TtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in TTS
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in TTS tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank_emilia"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=100,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = SAMPLING_RATE
config = TorchAudioFbankConfig(
sampling_rate=sampling_rate,
n_mels=100,
n_fft=1024,
hop_length=256,
)
train = SpeechSynthesisDataset(
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = SAMPLING_RATE
config = TorchAudioFbankConfig(
sampling_rate=sampling_rate,
n_mels=100,
n_fft=1024,
hop_length=256,
)
validate = SpeechSynthesisDataset(
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
dev_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create valid dataloader")
dev_dl = DataLoader(
validate,
sampler=dev_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return dev_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = SAMPLING_RATE
config = TorchAudioFbankConfig(
sampling_rate=sampling_rate,
n_mels=100,
n_fft=1024,
hop_length=256,
)
test = SpeechSynthesisDataset(
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
return_cuts=self.args.return_cuts,
return_audio=True,
)
else:
test = SpeechSynthesisDataset(
return_text=True,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
return_audio=True,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_emilia_EN_cuts(self) -> CutSet:
logging.info("About to get train the EN subset")
return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz")
@lru_cache()
def train_emilia_ZH_cuts(self) -> CutSet:
logging.info("About to get train the ZH subset")
return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz")
@lru_cache()
def dev_emilia_EN_cuts(self) -> CutSet:
logging.info("About to get dev the EN subset")
return load_manifest_lazy(
self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz"
)
@lru_cache()
def dev_emilia_ZH_cuts(self) -> CutSet:
logging.info("About to get dev the ZH subset")
return load_manifest_lazy(
self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz"
)
@lru_cache()
def train_libritts_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_libritts_cuts(self) -> CutSet:
logging.info("About to get dev-clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz"
)
class SpeechSynthesisDataset(torch.utils.data.Dataset):
"""
The PyTorch Dataset for the speech synthesis task.
Each item in this dataset is a dict of:
.. code-block::
{
'audio': (B x NumSamples) float tensor
'features': (B x NumFrames x NumFeatures) float tensor
'audio_lens': (B, ) int tensor
'features_lens': (B, ) int tensor
'text': List[str] of len B # when return_text=True
'tokens': List[List[str]] # when return_tokens=True
'speakers': List[str] of len B # when return_spk_ids=True
'cut': List of Cuts # when return_cuts=True
}
"""
def __init__(
self,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
feature_input_strategy: BatchIO = PrecomputedFeatures(),
feature_transforms: Union[Sequence[Callable], Callable] = None,
return_text: bool = True,
return_tokens: bool = False,
return_spk_ids: bool = False,
return_cuts: bool = False,
return_audio: bool = False,
) -> None:
super().__init__()
self.cut_transforms = ifnone(cut_transforms, [])
self.feature_input_strategy = feature_input_strategy
self.return_text = return_text
self.return_tokens = return_tokens
self.return_spk_ids = return_spk_ids
self.return_cuts = return_cuts
self.return_audio = return_audio
if feature_transforms is None:
feature_transforms = []
elif not isinstance(feature_transforms, Sequence):
feature_transforms = [feature_transforms]
assert all(
isinstance(transform, Callable) for transform in feature_transforms
), "Feature transforms must be Callable"
self.feature_transforms = feature_transforms
def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
validate_for_tts(cuts)
for transform in self.cut_transforms:
cuts = transform(cuts)
features, features_lens = self.feature_input_strategy(cuts)
for transform in self.feature_transforms:
features = transform(features)
batch = {
"features": features,
"features_lens": features_lens,
}
if self.return_audio:
audio, audio_lens = collate_audio(cuts)
batch["audio"] = audio
batch["audio_lens"] = audio_lens
if self.return_text:
# use normalized text
text = [cut.supervisions[0].normalized_text for cut in cuts]
batch["text"] = text
if self.return_tokens:
tokens = [cut.tokens for cut in cuts]
batch["tokens"] = tokens
if self.return_spk_ids:
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
if self.return_cuts:
batch["cut"] = [cut for cut in cuts]
return batch
def validate_for_tts(cuts: CutSet) -> None:
validate(cuts)
for cut in cuts:
assert (
len(cut.supervisions) == 1
), "Only the Cuts with single supervision are supported."

View File

@ -0,0 +1,219 @@
from typing import Any, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
class AttributeDict(dict):
def __getattr__(self, key):
if key in self:
return self[key]
raise AttributeError(f"No such attribute '{key}'")
def __setattr__(self, key, value):
self[key] = value
def __delattr__(self, key):
if key in self:
del self[key]
return
raise AttributeError(f"No such attribute '{key}'")
def prepare_input(
params: AttributeDict,
batch: dict,
device: torch.device,
tokenizer: Optional[Any] = None,
return_tokens: bool = False,
return_feature: bool = False,
return_audio: bool = False,
return_prompt: bool = False,
):
"""
Parse the features and targets of the current batch.
Args:
params:
It is returned by :func:`get_params`.
batch:
It is the return value from iterating
`SpeechSynthesisDataset`. See its documentation
for the format of the `batch`.
tokenizer:
Used to convert texts or tokens to token ids.
device:
The device of Tensor.
"""
return_list = []
if return_tokens:
assert tokenizer is not None
if params.token_type == "phone":
tokens = tokenizer.tokens_to_token_ids(batch["tokens"])
else:
tokens = tokenizer.texts_to_token_ids(batch["text"])
return_list += [tokens]
if return_feature:
features = batch["features"].to(device)
features_lens = batch["features_lens"].to(device)
return_list += [features * params.feat_scale, features_lens]
if return_audio:
return_list += [batch["audio"], batch["audio_lens"]]
if return_prompt:
if return_tokens:
if params.token_type == "phone":
prompt_tokens = tokenizer.tokens_to_token_ids(batch["prompt"]["tokens"])
else:
prompt_tokens = tokenizer.texts_to_token_ids(batch["prompt"]["text"])
return_list += [prompt_tokens]
if return_feature:
prompt_features = batch["prompt"]["features"].to(device)
prompt_features_lens = batch["prompt"]["features_lens"].to(device)
return_list += [prompt_features * params.feat_scale, prompt_features_lens]
if return_audio:
return_list += [batch["prompt"]["audio"], batch["prompt"]["audio_lens"]]
return return_list
def prepare_avg_tokens_durations(features_lens, tokens_lens):
tokens_durations = []
for i in range(len(features_lens)):
utt_duration = features_lens[i]
avg_token_duration = utt_duration // tokens_lens[i]
tokens_durations.append([avg_token_duration] * tokens_lens[i])
return tokens_durations
def pad_labels(y: List[List[int]], pad_id: int, device: torch.device):
"""
Pad the token-id sequences to the same length with pad_id.
Args:
y: the token ids, which is a list of lists
Returns:
Return a Tensor of padded token ids.
"""
y = [l + [pad_id] for l in y]
length = max([len(l) for l in y])
y = [l + [pad_id] * (length - len(l)) for l in y]
return torch.tensor(y, dtype=torch.int64, device=device)
def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor:
"""
Gets position in the transcript for each frame, i.e. the position
in the symbol-sequence to look up.
Args:
durations:
Duration of each token in transcripts.
num_frames:
The maximum frame length of the current batch.
Returns:
Return a Tensor of shape (batch_size, num_frames)
"""
durations = [x + [num_frames - sum(x)] for x in durations]
batch_size = len(durations)
ans = torch.zeros(batch_size, num_frames, dtype=torch.int64)
for b in range(batch_size):
this_dur = durations[b]
cur_frame = 0
for i, d in enumerate(this_dur):
ans[b, cur_frame : cur_frame + d] = i
cur_frame += d
assert cur_frame == num_frames, (cur_frame, num_frames)
return ans
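# Illustrative sketch (not part of the recipe): with durations [[2, 3]] and
# num_frames = 7, frames 0-1 map to token 0, frames 2-4 to token 1, and the
# trailing 2 padding frames to the appended index 2.
def _demo_get_tokens_index():
    print(get_tokens_index([[2, 3]], num_frames=7))
    # tensor([[0, 0, 1, 1, 1, 2, 2]])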
def to_int_tuple(s: str):
return tuple(map(int, s.split(",")))
def get_adjusted_batch_count(params: AttributeDict) -> float:
# returns the number of batches we would have used so far if we had used the reference
# duration. This is for purposes of set_batch_count().
return (
params.batch_idx_train
* (params.max_duration * params.world_size)
/ params.ref_duration
)
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for name, module in model.named_modules():
if hasattr(module, "batch_count"):
module.batch_count = batch_count
if hasattr(module, "name"):
module.name = name
def condition_time_mask(
features_lens: torch.Tensor, mask_percent: Tuple[float, float], max_len: int = 0
) -> torch.Tensor:
"""
Apply Time masking.
Args:
features_lens:
input tensor of shape ``(B,)`` with the number of frames per utterance.
mask_percent:
the (min, max) fraction of each utterance to be masked.
max_len:
the minimum time dimension (T) of the returned mask.
Returns:
Return a 2-D bool tensor (B, T), where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
"""
mask_size = (
torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
* features_lens
).to(torch.int64)
mask_starts = (
torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
).to(torch.int64)
mask_ends = mask_starts + mask_size
max_len = max(max_len, features_lens.max())
seq_range = torch.arange(0, max_len, device=features_lens.device)
mask = (seq_range[None, :] >= mask_starts[:, None]) & (
seq_range[None, :] < mask_ends[:, None]
)
return mask
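# Illustrative sketch (not part of the recipe): condition_time_mask() draws one
# random contiguous span per utterance whose width is a random fraction of the
# utterance length taken from mask_percent (e.g. 30%-100% here).
def _demo_condition_time_mask():
    torch.manual_seed(0)
    features_lens = torch.tensor([8, 6])
    mask = condition_time_mask(features_lens, mask_percent=(0.3, 1.0))
    print(mask.shape)        # torch.Size([2, 8])
    print(mask.sum(dim=1))   # number of masked frames per utterance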
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
>>> lengths = torch.tensor([1, 3, 2, 5])
>>> make_pad_mask(lengths)
tensor([[False, True, True, True, True],
[False, False, False, True, True],
[False, False, True, True, True],
[False, False, False, False, False]])
"""
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
return expanded_lengths >= lengths.unsqueeze(-1)

File diff suppressed because it is too large

View File

@ -0,0 +1,642 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script generates speech with our pre-trained ZipVoice or
ZipVoice-Distill models. Required models will be automatically
downloaded from HuggingFace.
Usage:
Note: If you have trouble connecting to HuggingFace,
try switching the endpoint to a mirror site:
export HF_ENDPOINT=https://hf-mirror.com
(1) Inference of a single sentence:
python3 zipvoice/zipvoice_infer.py \
--model-name "zipvoice_distill" \
--prompt-wav prompt.wav \
--prompt-text "I am a prompt." \
--text "I am a sentence." \
--res-wav-path result.wav
(2) Inference of a list of sentences:
python3 zipvoice/zipvoice_infer.py \
--model-name "zipvoice-distill" \
--test-list test.tsv \
--res-dir results
`--model-name` can be `zipvoice` or `zipvoice_distill`,
which are the models before and after distillation, respectively.
Each line of `test.tsv` is in the format of
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
"""
import argparse
import datetime as dt
import os
import numpy as np
import safetensors.torch
import soundfile as sf
import torch
import torch.nn as nn
import torchaudio
from feature import TorchAudioFbank, TorchAudioFbankConfig
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from model import get_distill_model, get_model
from tokenizer import TokenizerEmilia
from utils import AttributeDict
from vocos import Vocos
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-name",
type=str,
default="zipvoice_distill",
choices=["zipvoice", "zipvoice_distill"],
help="The model used for inference",
)
parser.add_argument(
"--test-list",
type=str,
default=None,
help="The list of prompt speech, prompt_transcription, "
"and text to synthesizein the format of "
"'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
)
parser.add_argument(
"--prompt-wav",
type=str,
default=None,
help="The prompt wav to mimic",
)
parser.add_argument(
"--prompt-text",
type=str,
default=None,
help="The transcription of the prompt wav",
)
parser.add_argument(
"--text",
type=str,
default=None,
help="The text to synthesize",
)
parser.add_argument(
"--res-dir",
type=str,
default="results",
help="Path name of the generated wavs dir, "
"used when decdode-list is not None",
)
parser.add_argument(
"--res-wav-path",
type=str,
default="result.wav",
help="Path name of the generated wav path, " "used when decdode-list is None",
)
parser.add_argument(
"--guidance-scale",
type=float,
default=None,
help="The scale of classifier-free guidance during inference.",
)
parser.add_argument(
"--num-step",
type=int,
default=None,
help="The number of sampling steps.",
)
parser.add_argument(
"--feat-scale",
type=float,
default=0.1,
help="The scale factor of fbank feature",
)
parser.add_argument(
"--speed",
type=float,
default=1.0,
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
)
parser.add_argument(
"--t-shift",
type=float,
default=0.5,
help="Shift t to smaller ones if t_shift < 1.0",
)
parser.add_argument(
"--target-rms",
type=float,
default=0.1,
help="Target speech normalization rms value",
)
parser.add_argument(
"--seed",
type=int,
default=666,
help="Random seed",
)
add_model_arguments(parser)
return parser
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--fm-decoder-downsampling-factor",
type=str,
default="1,2,4,2,1",
help="Downsampling factor for each stack of encoder layers.",
)
parser.add_argument(
"--fm-decoder-num-layers",
type=str,
default="2,2,4,4,4",
help="Number of zipformer encoder layers per stack, comma separated.",
)
parser.add_argument(
"--fm-decoder-cnn-module-kernel",
type=str,
default="31,15,7,15,31",
help="Sizes of convolutional kernels in convolution modules "
"in each encoder stack: a single int or comma-separated list.",
)
parser.add_argument(
"--fm-decoder-feedforward-dim",
type=int,
default=1536,
help="Feedforward dimension of the zipformer encoder layers, "
"per stack, comma separated.",
)
parser.add_argument(
"--fm-decoder-num-heads",
type=int,
default=4,
help="Number of attention heads in the zipformer encoder layers: "
"a single int or comma-separated list.",
)
parser.add_argument(
"--fm-decoder-dim",
type=int,
default=512,
help="Embedding dimension in encoder stacks: a single int "
"or comma-separated list.",
)
parser.add_argument(
"--text-encoder-downsampling-factor",
type=str,
default="1",
help="Downsampling factor for each stack of encoder layers.",
)
parser.add_argument(
"--text-encoder-num-layers",
type=str,
default="4",
help="Number of zipformer encoder layers per stack, comma separated.",
)
parser.add_argument(
"--text-encoder-feedforward-dim",
type=int,
default=512,
help="Feedforward dimension of the zipformer encoder layers, "
"per stack, comma separated.",
)
parser.add_argument(
"--text-encoder-cnn-module-kernel",
type=str,
default="9",
help="Sizes of convolutional kernels in convolution modules in "
"each encoder stack: a single int or comma-separated list.",
)
parser.add_argument(
"--text-encoder-num-heads",
type=int,
default=4,
help="Number of attention heads in the zipformer encoder layers: "
"a single int or comma-separated list.",
)
parser.add_argument(
"--text-encoder-dim",
type=int,
default=192,
help="Embedding dimension in encoder stacks: "
"a single int or comma-separated list.",
)
parser.add_argument(
"--query-head-dim",
type=int,
default=32,
help="Query/key dimension per head in encoder stacks: "
"a single int or comma-separated list.",
)
parser.add_argument(
"--value-head-dim",
type=int,
default=12,
help="Value dimension per head in encoder stacks: "
"a single int or comma-separated list.",
)
parser.add_argument(
"--pos-head-dim",
type=int,
default=4,
help="Positional-encoding dimension per head in encoder stacks: "
"a single int or comma-separated list.",
)
parser.add_argument(
"--pos-dim",
type=int,
default=48,
help="Positional-encoding embedding dimension",
)
parser.add_argument(
"--time-embed-dim",
type=int,
default=192,
help="Embedding dimension of timestamps embedding.",
)
parser.add_argument(
"--text-embed-dim",
type=int,
default=192,
help="Embedding dimension of text embedding.",
)
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sampling_rate": 24000,
"frame_shift_ms": 256 / 24000 * 1000,
"feat_dim": 100,
}
)
return params
def get_vocoder():
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
return vocoder
def generate_sentence(
save_path: str,
prompt_text: str,
prompt_wav: str,
text: str,
model: nn.Module,
vocoder: nn.Module,
tokenizer: TokenizerEmilia,
feature_extractor: TorchAudioFbank,
device: torch.device,
num_step: int = 16,
guidance_scale: float = 1.0,
speed: float = 1.0,
t_shift: float = 0.5,
target_rms: float = 0.1,
feat_scale: float = 0.1,
sampling_rate: int = 24000,
):
"""
Generate waveform of a text based on a given prompt
waveform and its transcription.
Args:
save_path (str): Path to save the generated wav.
prompt_text (str): Transcription of the prompt wav.
prompt_wav (str): Path to the prompt wav file.
text (str): Text to be synthesized into a waveform.
model (nn.Module): The model used for generation.
vocoder (nn.Module): The vocoder used to convert features to waveforms.
tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
feature_extractor (TorchAudioFbank): The feature extractor used to
extract acoustic features.
device (torch.device): The device on which computations are performed.
num_step (int, optional): Number of steps for decoding. Defaults to 16.
guidance_scale (float, optional): Scale for classifier-free guidance.
Defaults to 1.0.
speed (float, optional): Speed control. Defaults to 1.0.
t_shift (float, optional): Time shift. Defaults to 0.5.
target_rms (float, optional): Target RMS for waveform normalization.
Defaults to 0.1.
feat_scale (float, optional): Scale for features.
Defaults to 0.1.
sampling_rate (int, optional): Sampling rate for the waveform.
Defaults to 24000.
Returns:
metrics (dict): Dictionary containing time and real-time
factor metrics for processing.
"""
# Convert text to tokens
tokens = tokenizer.texts_to_token_ids([text])
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
# Load and preprocess prompt wav
prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
if prompt_rms < target_rms:
prompt_wav = prompt_wav * target_rms / prompt_rms
if prompt_sampling_rate != sampling_rate:
resampler = torchaudio.transforms.Resample(
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
)
prompt_wav = resampler(prompt_wav)
# Extract features from prompt wav
prompt_features = feature_extractor.extract(
prompt_wav, sampling_rate=sampling_rate
).to(device)
prompt_features = prompt_features.unsqueeze(0) * feat_scale
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
# Start timing
start_t = dt.datetime.now()
# Generate features
(
pred_features,
pred_features_lens,
pred_prompt_features,
pred_prompt_features_lens,
) = model.sample(
tokens=tokens,
prompt_tokens=prompt_tokens,
prompt_features=prompt_features,
prompt_features_lens=prompt_features_lens,
speed=speed,
t_shift=t_shift,
duration="predict",
num_step=num_step,
guidance_scale=guidance_scale,
)
# Postprocess predicted features
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
# Start vocoder processing
start_vocoder_t = dt.datetime.now()
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
# Calculate processing times and real-time factors
t = (dt.datetime.now() - start_t).total_seconds()
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
wav_seconds = wav.shape[-1] / sampling_rate
rtf = t / wav_seconds
rtf_no_vocoder = t_no_vocoder / wav_seconds
rtf_vocoder = t_vocoder / wav_seconds
metrics = {
"t": t,
"t_no_vocoder": t_no_vocoder,
"t_vocoder": t_vocoder,
"wav_seconds": wav_seconds,
"rtf": rtf,
"rtf_no_vocoder": rtf_no_vocoder,
"rtf_vocoder": rtf_vocoder,
}
# Adjust wav volume if necessary
if prompt_rms < target_rms:
wav = wav * prompt_rms / target_rms
wav = wav[0].cpu().numpy()
sf.write(save_path, wav, sampling_rate)
return metrics
def generate(
res_dir: str,
test_list: str,
model: nn.Module,
vocoder: nn.Module,
tokenizer: TokenizerEmilia,
feature_extractor: TorchAudioFbank,
device: torch.device,
num_step: int = 16,
guidance_scale: float = 1.0,
speed: float = 1.0,
t_shift: float = 0.5,
target_rms: float = 0.1,
feat_scale: float = 0.1,
sampling_rate: int = 24000,
):
total_t = []
total_t_no_vocoder = []
total_t_vocoder = []
total_wav_seconds = []
with open(test_list, "r") as fr:
lines = fr.readlines()
for i, line in enumerate(lines):
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
save_path = f"{res_dir}/{wav_name}.wav"
metrics = generate_sentence(
save_path=save_path,
prompt_text=prompt_text,
prompt_wav=prompt_wav,
text=text,
model=model,
vocoder=vocoder,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
device=device,
num_step=num_step,
guidance_scale=guidance_scale,
speed=speed,
t_shift=t_shift,
target_rms=target_rms,
feat_scale=feat_scale,
sampling_rate=sampling_rate,
)
print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
total_t.append(metrics["t"])
total_t_no_vocoder.append(metrics["t_no_vocoder"])
total_t_vocoder.append(metrics["t_vocoder"])
total_wav_seconds.append(metrics["wav_seconds"])
print(f"Average RTF: {np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
print(
f"Average RTF w/o vocoder: "
f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
)
print(
f"Average RTF vocoder: "
f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
)
@torch.inference_mode()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
model_defaults = {
"zipvoice": {
"num_step": 16,
"guidance_scale": 1.0,
},
"zipvoice_distill": {
"num_step": 8,
"guidance_scale": 3.0,
},
}
model_specific_defaults = model_defaults.get(params.model_name, {})
for param, value in model_specific_defaults.items():
if getattr(params, param) == parser.get_default(param):
setattr(params, param, value)
print(f"Setting {param} to default value: {value}")
assert (params.test_list is not None) ^ (
(params.prompt_wav and params.prompt_text and params.text) is not None
), (
"For inference, please provide prompts and text with either '--test-list'"
" or '--prompt-wav, --prompt-text and --text'."
)
if torch.cuda.is_available():
params.device = torch.device("cuda", 0)
else:
params.device = torch.device("cpu")
token_file = hf_hub_download("zhu-han/ZipVoice", filename="tokens_emilia.txt")
tokenizer = TokenizerEmilia(token_file)
params.vocab_size = tokenizer.vocab_size
params.pad_id = tokenizer.pad_id
fix_random_seed(params.seed)
if params.model_name == "zipvoice_distill":
model = get_distill_model(params)
model_ckpt = hf_hub_download(
"zhu-han/ZipVoice", filename="exp_zipvoice_distill/model.safetensors"
)
else:
model = get_model(params)
model_ckpt = hf_hub_download(
"zhu-han/ZipVoice", filename="exp_zipvoice/model.safetensors"
)
safetensors.torch.load_model(model, model_ckpt)
model = model.to(params.device)
model.eval()
vocoder = get_vocoder()
vocoder = vocoder.to(params.device)
vocoder.eval()
config = TorchAudioFbankConfig(
sampling_rate=params.sampling_rate,
n_mels=100,
n_fft=1024,
hop_length=256,
)
feature_extractor = TorchAudioFbank(config)
if params.test_list:
os.makedirs(params.res_dir, exist_ok=True)
generate(
res_dir=params.res_dir,
test_list=params.test_list,
model=model,
vocoder=vocoder,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
device=params.device,
num_step=params.num_step,
guidance_scale=params.guidance_scale,
speed=params.speed,
t_shift=params.t_shift,
target_rms=params.target_rms,
feat_scale=params.feat_scale,
sampling_rate=params.sampling_rate,
)
else:
generate_sentence(
save_path=params.res_wav_path,
prompt_text=params.prompt_text,
prompt_wav=params.prompt_wav,
text=params.text,
model=model,
vocoder=vocoder,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
device=params.device,
num_step=params.num_step,
guidance_scale=params.guidance_scale,
speed=params.speed,
t_shift=params.t_shift,
target_rms=params.target_rms,
feat_scale=params.feat_scale,
sampling_rate=params.sampling_rate,
)
print("Done")
if __name__ == "__main__":
main()