Mirror of https://github.com/k2-fsa/icefall.git (commit 8c529ebe90)

egs/zipvoice/README.md

## ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching

[Paper](http://arxiv.org/abs/2506.13053)
[Demo Page](https://zipvoice.github.io/)

## Overview

ZipVoice is a high-quality zero-shot TTS model with a small model size and fast inference speed.

#### Key features:

- Small and fast: only 123M parameters.

- High-quality: state-of-the-art voice cloning performance in speaker similarity, intelligibility, and naturalness.

- Multi-lingual: supports Chinese and English.

## News

**2025/06/16**: 🔥 ZipVoice is released.

## Installation

```
pip install -r requirements.txt
```

## Usage

To generate speech with our pre-trained ZipVoice or ZipVoice-Distill models, use the following commands (required models will be downloaded from HuggingFace automatically):

### 1. Inference of a single sentence:

```bash
python3 zipvoice/zipvoice_infer.py \
    --model-name "zipvoice_distill" \
    --prompt-wav prompt.wav \
    --prompt-text "I am the transcription of the prompt wav." \
    --text "I am the text to be synthesized." \
    --res-wav-path result.wav
```

### 2. Inference of a list of sentences:

```bash
python3 zipvoice/zipvoice_infer.py \
    --model-name "zipvoice_distill" \
    --test-list test.tsv \
    --res-dir results/test
```

- `--model-name` can be `zipvoice` or `zipvoice_distill`, which are the models before and after distillation, respectively.
- Each line of `test.tsv` is in the format `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`; an example is shown below.
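
For reference, a hypothetical single-entry `test.tsv` could be created like this (the wav name, prompt transcription, and prompt wav path below are placeholders):

```bash
printf 'utt_0001\tI am the transcription of the prompt wav.\tprompt.wav\tI am the text to be synthesized.\n' > test.tsv
```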

> **Note:** If you have trouble connecting to HuggingFace, try:

```bash
export HF_ENDPOINT=https://hf-mirror.com
```

## Training Your Own Model

The following steps show how to train a model from scratch on the Emilia and LibriTTS datasets, respectively.

### 1. Data Preparation

#### 1.1 Prepare the Emilia dataset

#### 1.2 Prepare the LibriTTS dataset

See [local/prepare_libritts.sh](local/prepare_libritts.sh); it can be run as shown below.
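
A minimal sketch of running the preparation from the recipe directory (see the script itself for its actual stages and outputs):

```bash
cd egs/zipvoice
bash local/prepare_libritts.sh
```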

### 2. Training

#### 2.1 Training on Emilia

<details>
<summary>Expand to view training steps</summary>

##### 2.1.1 Train the ZipVoice model

- Training:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_flow.py \
    --world-size 8 \
    --use-fp16 1 \
    --dataset emilia \
    --max-duration 500 \
    --lr-hours 30000 \
    --lr-batches 7500 \
    --token-file "data/tokens_emilia.txt" \
    --manifest-dir "data/fbank_emilia" \
    --num-epochs 11 \
    --exp-dir zipvoice/exp_zipvoice
```

- Average the checkpoints to produce the final model:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/generate_averaged_model.py \
    --epoch 11 \
    --avg 4 \
    --distill 0 \
    --token-file data/tokens_emilia.txt \
    --dataset "emilia" \
    --exp-dir ./zipvoice/exp_zipvoice
# The generated model is zipvoice/exp_zipvoice/epoch-11-avg-4.pt
```

##### 2.1.2 Train the ZipVoice-Distill model (Optional)

- The first-stage distillation:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_distill.py \
    --world-size 8 \
    --use-fp16 1 \
    --tensorboard 1 \
    --dataset "emilia" \
    --base-lr 0.0005 \
    --max-duration 500 \
    --token-file "data/tokens_emilia.txt" \
    --manifest-dir "data/fbank_emilia" \
    --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
    --num-updates 60000 \
    --distill-stage "first" \
    --exp-dir zipvoice/exp_zipvoice_distill_1stage
```

- Average checkpoints for the second-stage initialization:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/generate_averaged_model.py \
    --iter 60000 \
    --avg 7 \
    --distill 1 \
    --token-file data/tokens_emilia.txt \
    --dataset "emilia" \
    --exp-dir ./zipvoice/exp_zipvoice_distill_1stage
# The generated model is zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt
```

- The second-stage distillation:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_distill.py \
    --world-size 8 \
    --use-fp16 1 \
    --tensorboard 1 \
    --dataset "emilia" \
    --base-lr 0.0001 \
    --max-duration 200 \
    --token-file "data/tokens_emilia.txt" \
    --manifest-dir "data/fbank_emilia" \
    --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \
    --num-updates 2000 \
    --distill-stage "second" \
    --exp-dir zipvoice/exp_zipvoice_distill_new
```

</details>

#### 2.2 Training on LibriTTS

<details>
<summary>Expand to view training steps</summary>

##### 2.2.1 Train the ZipVoice model

- Training:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_flow.py \
    --world-size 8 \
    --use-fp16 1 \
    --dataset libritts \
    --max-duration 250 \
    --lr-epochs 10 \
    --lr-batches 7500 \
    --token-file "data/tokens_libritts.txt" \
    --manifest-dir "data/fbank_libritts" \
    --num-epochs 60 \
    --exp-dir zipvoice/exp_zipvoice_libritts
```

- Average the checkpoints to produce the final model:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/generate_averaged_model.py \
    --epoch 60 \
    --avg 10 \
    --distill 0 \
    --token-file data/tokens_libritts.txt \
    --dataset "libritts" \
    --exp-dir ./zipvoice/exp_zipvoice_libritts
# The generated model is zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt
```

##### 2.2.2 Train the ZipVoice-Distill model (Optional)

- The first-stage distillation:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_distill.py \
    --world-size 8 \
    --use-fp16 1 \
    --tensorboard 1 \
    --dataset "libritts" \
    --base-lr 0.001 \
    --max-duration 250 \
    --token-file "data/tokens_libritts.txt" \
    --manifest-dir "data/fbank_libritts" \
    --teacher-model zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
    --num-epochs 6 \
    --distill-stage "first" \
    --exp-dir zipvoice/exp_zipvoice_distill_1stage_libritts
```

- Average checkpoints for the second-stage initialization:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 ./zipvoice/generate_averaged_model.py \
    --epoch 6 \
    --avg 3 \
    --distill 1 \
    --token-file data/tokens_libritts.txt \
    --dataset "libritts" \
    --exp-dir ./zipvoice/exp_zipvoice_distill_1stage_libritts
# The generated model is zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt
```

- The second-stage distillation:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/train_distill.py \
    --world-size 8 \
    --use-fp16 1 \
    --tensorboard 1 \
    --dataset "libritts" \
    --base-lr 0.001 \
    --max-duration 250 \
    --token-file "data/tokens_libritts.txt" \
    --manifest-dir "data/fbank_libritts" \
    --teacher-model zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \
    --num-epochs 6 \
    --distill-stage "second" \
    --exp-dir zipvoice/exp_zipvoice_distill_libritts
```

- Average checkpoints to produce the final model:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 ./zipvoice/generate_averaged_model.py \
    --epoch 6 \
    --avg 3 \
    --distill 1 \
    --token-file data/tokens_libritts.txt \
    --dataset "libritts" \
    --exp-dir ./zipvoice/exp_zipvoice_distill_libritts
# The generated model is ./zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt
```

</details>

### 3. Inference with the trained model

#### 3.1 Inference with the model trained on Emilia

<details>
<summary>Expand to view inference commands.</summary>

##### 3.1.1 ZipVoice model (before distillation):

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/infer.py \
    --checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
    --distill 0 \
    --token-file "data/tokens_emilia.txt" \
    --test-list test.tsv \
    --res-dir results/test \
    --num-step 16 \
    --guidance-scale 1
```

##### 3.1.2 ZipVoice-Distill model (after distillation):

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/infer.py \
    --checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
    --distill 1 \
    --token-file "data/tokens_emilia.txt" \
    --test-list test.tsv \
    --res-dir results/test_distill \
    --num-step 8 \
    --guidance-scale 3
```

</details>

#### 3.2 Inference with the model trained on LibriTTS

<details>
<summary>Expand to view inference commands.</summary>

##### 3.2.1 ZipVoice model (before distillation):

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/infer.py \
    --checkpoint zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
    --distill 0 \
    --token-file "data/tokens_libritts.txt" \
    --test-list test.tsv \
    --res-dir results/test_libritts \
    --num-step 8 \
    --guidance-scale 1 \
    --target-rms 1.0 \
    --t-shift 0.7
```

##### 3.2.2 ZipVoice-Distill model (after distillation):

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 zipvoice/infer.py \
    --checkpoint zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt \
    --distill 1 \
    --token-file "data/tokens_libritts.txt" \
    --test-list test.tsv \
    --res-dir results/test_distill_libritts \
    --num-step 4 \
    --guidance-scale 3 \
    --target-rms 1.0 \
    --t-shift 0.7
```

</details>

### 4. Evaluation on benchmarks

See [local/evaluate.sh](local/evaluate.sh) for details of the objective-metric evaluation
on three test sets, i.e., LibriSpeech-PC test-clean, Seed-TTS test-en, and Seed-TTS test-zh; a typical invocation is shown below.
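
A sketch of running the full evaluation, assuming the synthesized test sets have already been generated under `results/` with the directory names the script expects:

```bash
cd egs/zipvoice
bash local/evaluate.sh
```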

## Citation

```bibtex
@article{zhu-2025-zipvoice,
    title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching},
    author={Han Zhu and Wei Kang and Zengwei Yao and Liyong Guo and Fangjun Kuang and Zhaoqing Li and Weiji Zhuang and Long Lin and Daniel Povey},
    journal={arXiv preprint arXiv:2506.13053},
    year={2025},
}
```

egs/zipvoice/local/compute_fbank_libritts.py

#!/usr/bin/env python3
# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                       Zengwei Yao,)
#              2024       The Chinese Univ. of HK (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the LibriTTS dataset.
It looks for manifests in the directory data/manifests_libritts.

The generated fbank features are saved in data/fbank_libritts.
"""

import argparse
import logging
import os
from pathlib import Path
from typing import Optional

import torch
from feature import TorchAudioFbank, TorchAudioFbankConfig
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--dataset",
        type=str,
        help="""Dataset parts to compute fbank. If None, we will use all""",
    )
    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=24000,
        help="""Sampling rate of the waveform for computing fbank,
        the default value for LibriTTS is 24000, waveform files will be
        resampled if a different sample rate is provided""",
    )

    return parser.parse_args()


def compute_fbank_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000):
    src_dir = Path("data/manifests_libritts")
    output_dir = Path("data/fbank_libritts")
    num_jobs = min(32, os.cpu_count())

    prefix = "libritts"
    suffix = "jsonl.gz"
    if dataset is None:
        dataset_parts = (
            "dev-clean",
            "test-clean",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        )
    else:
        dataset_parts = dataset.split(" ", -1)

    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    config = TorchAudioFbankConfig(
        sampling_rate=sampling_rate,
        n_mels=100,
        n_fft=1024,
        hop_length=256,
    )
    extractor = TorchAudioFbank(config)

    with get_executor() as ex:  # Initialize the executor only once.
        for partition, m in manifests.items():
            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
            if (output_dir / cuts_filename).is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if sampling_rate != 24000:
                logging.info(f"Resampling waveforms to {sampling_rate}")
                cut_set = cut_set.resample(sampling_rate)

            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                # when an executor is specified, make more partitions
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
            cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    args = get_args()
    compute_fbank_libritts(dataset=args.dataset, sampling_rate=args.sampling_rate)
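
A minimal invocation from the recipe directory, assuming the LibriTTS manifests have already been written under `data/manifests_libritts`; the `--dataset` value is a space-separated list of parts, as parsed in `get_args` above:

```bash
export PYTHONPATH=../../:$PYTHONPATH
python3 local/compute_fbank_libritts.py \
    --dataset "train-clean-100 dev-clean" \
    --sampling-rate 24000
```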

egs/zipvoice/local/evaluate.sh

export CUDA_VISIBLE_DEVICES="0"
export PYTHONWARNINGS=ignore
export PYTHONPATH=../../:$PYTHONPATH

# Uncomment this if you have trouble connecting to HuggingFace
# export HF_ENDPOINT=https://hf-mirror.com

start_stage=1
end_stage=3

# Models used for SIM-o evaluation.
# SV model wavlm_large_finetune.pth is downloaded from https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
# SSL model wavlm_large.pt is downloaded from https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
sv_model_path=model/UniSpeech/wavlm_large_finetune.pth
wavlm_model_path=model/s3prl/wavlm_large.pt

# Models used for UTMOS evaluation.
# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
utmos_model_path=model/huggingface/utmos/utmos.pt
wav2vec_model_path=model/huggingface/utmos/wav2vec_small.pt

if [ $start_stage -le 1 ] && [ $end_stage -ge 1 ]; then

  echo "=====Evaluate for Seed-TTS test-en======="
  test_list=testset/test_seedtts_en.tsv
  wav_path=results/zipvoice_seedtts_en

  echo $wav_path
  echo "-----Computing SIM-o-----"
  python3 local/evaluate_sim.py \
    --sv-model-path ${sv_model_path} \
    --ssl-model-path ${wavlm_model_path} \
    --eval-path ${wav_path} \
    --test-list ${test_list}

  echo "-----Computing WER-----"
  python3 local/evaluate_wer_seedtts.py \
    --test-list ${test_list} \
    --wav-path ${wav_path} \
    --lang "en"

  echo "-----Computing UTMOS-----"
  python3 local/evaluate_utmos.py \
    --wav-path ${wav_path} \
    --utmos-model-path ${utmos_model_path} \
    --ssl-model-path ${wav2vec_model_path}

fi

if [ $start_stage -le 2 ] && [ $end_stage -ge 2 ]; then
  echo "=====Evaluate for Seed-TTS test-zh======="
  test_list=testset/test_seedtts_zh.tsv
  wav_path=results/zipvoice_seedtts_zh

  echo $wav_path
  echo "-----Computing SIM-o-----"
  python3 local/evaluate_sim.py \
    --sv-model-path ${sv_model_path} \
    --ssl-model-path ${wavlm_model_path} \
    --eval-path ${wav_path} \
    --test-list ${test_list}

  echo "-----Computing WER-----"
  python3 local/evaluate_wer_seedtts.py \
    --test-list ${test_list} \
    --wav-path ${wav_path} \
    --lang "zh"

  echo "-----Computing UTMOS-----"
  python3 local/evaluate_utmos.py \
    --wav-path ${wav_path} \
    --utmos-model-path ${utmos_model_path} \
    --ssl-model-path ${wav2vec_model_path}
fi

if [ $start_stage -le 3 ] && [ $end_stage -ge 3 ]; then
  echo "=====Evaluate for LibriSpeech test-clean======="
  test_list=testset/test_librispeech_pc_test_clean.tsv
  wav_path=results/zipvoice_librispeech_test_clean

  echo $wav_path
  echo "-----Computing SIM-o-----"
  python3 local/evaluate_sim.py \
    --sv-model-path ${sv_model_path} \
    --ssl-model-path ${wavlm_model_path} \
    --eval-path ${wav_path} \
    --test-list ${test_list}

  echo "-----Computing WER-----"
  python3 local/evaluate_wer_hubert.py \
    --test-list ${test_list} \
    --wav-path ${wav_path}

  echo "-----Computing UTMOS-----"
  python3 local/evaluate_utmos.py \
    --wav-path ${wav_path} \
    --utmos-model-path ${utmos_model_path} \
    --ssl-model-path ${wav2vec_model_path}

fi
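
The model paths at the top of the script can be populated as described in the comments there; for example (a sketch, with the URLs copied from those comments):

```bash
mkdir -p model/UniSpeech model/s3prl model/huggingface/utmos
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
wget https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt -O model/s3prl/wavlm_large.pt
```

`wavlm_large_finetune.pth` is fetched separately from the UniSpeech speaker-verification page linked above.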

egs/zipvoice/local/evaluate_sim.py

"""
|
||||||
|
Calculate pairwise Speaker Similarity between two speech directories.
|
||||||
|
SV model wavlm_large_finetune.pth is downloaded from
|
||||||
|
https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
|
||||||
|
SSL model wavlm_large.pt is downloaded from
|
||||||
|
https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval-path", type=str, help="path of the evaluated speech directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-list",
|
||||||
|
type=str,
|
||||||
|
help="path of the file list that contains the corresponding "
|
||||||
|
"relationship between the prompt and evaluated speech. "
|
||||||
|
"The first column is the wav name and the third column is the prompt speech",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sv-model-path",
|
||||||
|
type=str,
|
||||||
|
default="model/UniSpeech/wavlm_large_finetune.pth",
|
||||||
|
help="path of the wavlm-based ECAPA-TDNN model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ssl-model-path",
|
||||||
|
type=str,
|
||||||
|
default="model/s3prl/wavlm_large.pt",
|
||||||
|
help="path of the wavlm SSL model",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class SpeakerSimilarity:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sv_model_path="model/UniSpeech/wavlm_large_finetune.pth",
|
||||||
|
ssl_model_path="model/s3prl/wavlm_large.pt",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize
|
||||||
|
"""
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.channels = 1
|
||||||
|
self.device = (
|
||||||
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
)
|
||||||
|
logging.info("[Speaker Similarity] Using device: {}".format(self.device))
|
||||||
|
self.model = ECAPA_TDNN_WAVLLM(
|
||||||
|
feat_dim=1024,
|
||||||
|
channels=512,
|
||||||
|
emb_dim=256,
|
||||||
|
sr=16000,
|
||||||
|
ssl_model_path=ssl_model_path,
|
||||||
|
)
|
||||||
|
state_dict = torch.load(
|
||||||
|
sv_model_path, map_location=lambda storage, loc: storage
|
||||||
|
)
|
||||||
|
self.model.load_state_dict(state_dict["model"], strict=False)
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def get_embeddings(self, wav_list, dtype="float32"):
|
||||||
|
"""
|
||||||
|
Get embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _load_speech_task(fname, sample_rate):
|
||||||
|
|
||||||
|
wav_data, sr = sf.read(fname, dtype=dtype)
|
||||||
|
if sr != sample_rate:
|
||||||
|
wav_data = librosa.resample(
|
||||||
|
wav_data, orig_sr=sr, target_sr=self.sample_rate
|
||||||
|
)
|
||||||
|
wav_data = torch.from_numpy(wav_data)
|
||||||
|
|
||||||
|
return wav_data
|
||||||
|
|
||||||
|
embd_lst = []
|
||||||
|
for file_path in tqdm(wav_list):
|
||||||
|
speech = _load_speech_task(file_path, self.sample_rate)
|
||||||
|
speech = speech.to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
embd = self.model([speech])
|
||||||
|
embd_lst.append(embd)
|
||||||
|
|
||||||
|
return embd_lst
|
||||||
|
|
||||||
|
def score(
|
||||||
|
self,
|
||||||
|
eval_path,
|
||||||
|
test_list,
|
||||||
|
dtype="float32",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Computes the Speaker Similarity (SIM-o) between two directories of speech files.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- eval_path (str): Path to the directory containing evaluation speech files.
|
||||||
|
- test_list (str): Path to the file containing the corresponding relationship
|
||||||
|
between prompt and evaluated speech.
|
||||||
|
- dtype (str, optional): Data type for loading speech. Default is "float32".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- float: The Speaker Similarity (SIM-o) score between the two directories
|
||||||
|
of speech files.
|
||||||
|
"""
|
||||||
|
prompt_wavs = []
|
||||||
|
eval_wavs = []
|
||||||
|
with open(test_list, "r") as fr:
|
||||||
|
lines = fr.readlines()
|
||||||
|
for line in lines:
|
||||||
|
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
|
||||||
|
prompt_wavs.append(prompt_wav)
|
||||||
|
eval_wavs.append(os.path.join(eval_path, wav_name + ".wav"))
|
||||||
|
embds_prompt = self.get_embeddings(prompt_wavs, dtype=dtype)
|
||||||
|
|
||||||
|
embds_eval = self.get_embeddings(eval_wavs, dtype=dtype)
|
||||||
|
|
||||||
|
# Check if embeddings are empty
|
||||||
|
if len(embds_prompt) == 0:
|
||||||
|
logging.info("[Speaker Similarity] real set dir is empty, exiting...")
|
||||||
|
return -1
|
||||||
|
if len(embds_eval) == 0:
|
||||||
|
logging.info("[Speaker Similarity] eval set dir is empty, exiting...")
|
||||||
|
return -1
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
for real_embd, eval_embd in zip(embds_prompt, embds_eval):
|
||||||
|
scores.append(
|
||||||
|
torch.nn.functional.cosine_similarity(real_embd, eval_embd, dim=-1)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
|
||||||
|
return np.mean(scores)
|
||||||
|
|
||||||
|
|
||||||
|
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
||||||
|
|
||||||
|
""" Res2Conv1d + BatchNorm1d + ReLU
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Res2Conv1dReluBn(nn.Module):
|
||||||
|
"""
|
||||||
|
in_channels == out_channels == channels
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
bias=True,
|
||||||
|
scale=4,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
||||||
|
self.scale = scale
|
||||||
|
self.width = channels // scale
|
||||||
|
self.nums = scale if scale == 1 else scale - 1
|
||||||
|
|
||||||
|
self.convs = []
|
||||||
|
self.bns = []
|
||||||
|
for i in range(self.nums):
|
||||||
|
self.convs.append(
|
||||||
|
nn.Conv1d(
|
||||||
|
self.width,
|
||||||
|
self.width,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.bns.append(nn.BatchNorm1d(self.width))
|
||||||
|
self.convs = nn.ModuleList(self.convs)
|
||||||
|
self.bns = nn.ModuleList(self.bns)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = []
|
||||||
|
spx = torch.split(x, self.width, 1)
|
||||||
|
for i in range(self.nums):
|
||||||
|
if i == 0:
|
||||||
|
sp = spx[i]
|
||||||
|
else:
|
||||||
|
sp = sp + spx[i]
|
||||||
|
# Order: conv -> relu -> bn
|
||||||
|
sp = self.convs[i](sp)
|
||||||
|
sp = self.bns[i](F.relu(sp))
|
||||||
|
out.append(sp)
|
||||||
|
if self.scale != 1:
|
||||||
|
out.append(spx[self.nums])
|
||||||
|
out = torch.cat(out, dim=1)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
""" Conv1d + BatchNorm1d + ReLU
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1dReluBn(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
bias=True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
|
||||||
|
)
|
||||||
|
self.bn = nn.BatchNorm1d(out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.bn(F.relu(self.conv(x)))
|
||||||
|
|
||||||
|
|
||||||
|
""" The SE connection of 1D case.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class SE_Connect(nn.Module):
|
||||||
|
def __init__(self, channels, se_bottleneck_dim=128):
|
||||||
|
super().__init__()
|
||||||
|
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
||||||
|
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = x.mean(dim=2)
|
||||||
|
out = F.relu(self.linear1(out))
|
||||||
|
out = torch.sigmoid(self.linear2(out))
|
||||||
|
out = x * out.unsqueeze(2)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
||||||
|
# return nn.Sequential(
|
||||||
|
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
|
||||||
|
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
|
||||||
|
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
|
||||||
|
# SE_Connect(channels)
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
class SE_Res2Block(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
scale,
|
||||||
|
se_bottleneck_dim,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.Conv1dReluBn1 = Conv1dReluBn(
|
||||||
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
self.Res2Conv1dReluBn = Res2Conv1dReluBn(
|
||||||
|
out_channels, kernel_size, stride, padding, dilation, scale=scale
|
||||||
|
)
|
||||||
|
self.Conv1dReluBn2 = Conv1dReluBn(
|
||||||
|
out_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||||
|
)
|
||||||
|
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
|
||||||
|
|
||||||
|
self.shortcut = None
|
||||||
|
if in_channels != out_channels:
|
||||||
|
self.shortcut = nn.Conv1d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
if self.shortcut:
|
||||||
|
residual = self.shortcut(x)
|
||||||
|
|
||||||
|
x = self.Conv1dReluBn1(x)
|
||||||
|
x = self.Res2Conv1dReluBn(x)
|
||||||
|
x = self.Conv1dReluBn2(x)
|
||||||
|
x = self.SE_Connect(x)
|
||||||
|
|
||||||
|
return x + residual
|
||||||
|
|
||||||
|
|
||||||
|
""" Attentive weighted mean and standard deviation pooling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AttentiveStatsPool(nn.Module):
|
||||||
|
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
|
||||||
|
super().__init__()
|
||||||
|
self.global_context_att = global_context_att
|
||||||
|
|
||||||
|
# Use Conv1d with stride == 1 rather than Linear,
|
||||||
|
# then we don't need to transpose inputs.
|
||||||
|
if global_context_att:
|
||||||
|
self.linear1 = nn.Conv1d(
|
||||||
|
in_dim * 3, attention_channels, kernel_size=1
|
||||||
|
) # equals W and b in the paper
|
||||||
|
else:
|
||||||
|
self.linear1 = nn.Conv1d(
|
||||||
|
in_dim, attention_channels, kernel_size=1
|
||||||
|
) # equals W and b in the paper
|
||||||
|
self.linear2 = nn.Conv1d(
|
||||||
|
attention_channels, in_dim, kernel_size=1
|
||||||
|
) # equals V and k in the paper
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
if self.global_context_att:
|
||||||
|
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
||||||
|
context_std = torch.sqrt(
|
||||||
|
torch.var(x, dim=-1, keepdim=True) + 1e-10
|
||||||
|
).expand_as(x)
|
||||||
|
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
||||||
|
else:
|
||||||
|
x_in = x
|
||||||
|
|
||||||
|
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
||||||
|
alpha = torch.tanh(self.linear1(x_in))
|
||||||
|
# alpha = F.relu(self.linear1(x_in))
|
||||||
|
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
||||||
|
mean = torch.sum(alpha * x, dim=2)
|
||||||
|
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
|
||||||
|
std = torch.sqrt(residuals.clamp(min=1e-9))
|
||||||
|
return torch.cat([mean, std], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class ECAPA_TDNN_WAVLLM(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
feat_dim=80,
|
||||||
|
channels=512,
|
||||||
|
emb_dim=192,
|
||||||
|
global_context_att=False,
|
||||||
|
sr=16000,
|
||||||
|
ssl_model_path=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.sr = sr
|
||||||
|
|
||||||
|
if ssl_model_path is None:
|
||||||
|
self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
|
||||||
|
else:
|
||||||
|
self.feature_extract = torch.hub.load(
|
||||||
|
os.path.dirname(ssl_model_path),
|
||||||
|
"wavlm_local",
|
||||||
|
source="local",
|
||||||
|
ckpt=ssl_model_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
||||||
|
self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
|
||||||
|
):
|
||||||
|
self.feature_extract.model.encoder.layers[
|
||||||
|
23
|
||||||
|
].self_attn.fp32_attention = False
|
||||||
|
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
||||||
|
self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
|
||||||
|
):
|
||||||
|
self.feature_extract.model.encoder.layers[
|
||||||
|
11
|
||||||
|
].self_attn.fp32_attention = False
|
||||||
|
|
||||||
|
self.feat_num = self.get_feat_num()
|
||||||
|
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
||||||
|
|
||||||
|
self.instance_norm = nn.InstanceNorm1d(feat_dim)
|
||||||
|
# self.channels = [channels] * 4 + [channels * 3]
|
||||||
|
self.channels = [channels] * 4 + [1536]
|
||||||
|
|
||||||
|
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
|
||||||
|
self.layer2 = SE_Res2Block(
|
||||||
|
self.channels[0],
|
||||||
|
self.channels[1],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=2,
|
||||||
|
dilation=2,
|
||||||
|
scale=8,
|
||||||
|
se_bottleneck_dim=128,
|
||||||
|
)
|
||||||
|
self.layer3 = SE_Res2Block(
|
||||||
|
self.channels[1],
|
||||||
|
self.channels[2],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=3,
|
||||||
|
dilation=3,
|
||||||
|
scale=8,
|
||||||
|
se_bottleneck_dim=128,
|
||||||
|
)
|
||||||
|
self.layer4 = SE_Res2Block(
|
||||||
|
self.channels[2],
|
||||||
|
self.channels[3],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=4,
|
||||||
|
dilation=4,
|
||||||
|
scale=8,
|
||||||
|
se_bottleneck_dim=128,
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
|
||||||
|
cat_channels = channels * 3
|
||||||
|
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
|
||||||
|
self.pooling = AttentiveStatsPool(
|
||||||
|
self.channels[-1],
|
||||||
|
attention_channels=128,
|
||||||
|
global_context_att=global_context_att,
|
||||||
|
)
|
||||||
|
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
|
||||||
|
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
|
||||||
|
|
||||||
|
def get_feat_num(self):
|
||||||
|
self.feature_extract.eval()
|
||||||
|
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
|
||||||
|
with torch.no_grad():
|
||||||
|
features = self.feature_extract(wav)
|
||||||
|
select_feature = features["hidden_states"]
|
||||||
|
if isinstance(select_feature, (list, tuple)):
|
||||||
|
return len(select_feature)
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def get_feat(self, x):
|
||||||
|
with torch.no_grad():
|
||||||
|
x = self.feature_extract([sample for sample in x])
|
||||||
|
|
||||||
|
x = x["hidden_states"]
|
||||||
|
if isinstance(x, (list, tuple)):
|
||||||
|
x = torch.stack(x, dim=0)
|
||||||
|
else:
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
norm_weights = (
|
||||||
|
F.softmax(self.feature_weight, dim=-1)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
)
|
||||||
|
x = (norm_weights * x).sum(dim=0)
|
||||||
|
x = torch.transpose(x, 1, 2) + 1e-6
|
||||||
|
|
||||||
|
x = self.instance_norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.get_feat(x)
|
||||||
|
|
||||||
|
out1 = self.layer1(x)
|
||||||
|
out2 = self.layer2(out1)
|
||||||
|
out3 = self.layer3(out2)
|
||||||
|
out4 = self.layer4(out3)
|
||||||
|
|
||||||
|
out = torch.cat([out2, out3, out4], dim=1)
|
||||||
|
out = F.relu(self.conv(out))
|
||||||
|
out = self.bn(self.pooling(out))
|
||||||
|
out = self.linear(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
SIM = SpeakerSimilarity(
|
||||||
|
sv_model_path=args.sv_model_path, ssl_model_path=args.ssl_model_path
|
||||||
|
)
|
||||||
|
score = SIM.score(args.eval_path, args.test_list)
|
||||||
|
logging.info(f"SIM-o score: {score:.3f}")

egs/zipvoice/local/evaluate_utmos.py

"""
|
||||||
|
Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
|
||||||
|
adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
|
||||||
|
|
||||||
|
# Download model checkpoints
|
||||||
|
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
|
||||||
|
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import fairseq
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--wav-path", type=str, help="path of the evaluated speech directory"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--utmos-model-path",
|
||||||
|
type=str,
|
||||||
|
default="model/huggingface/utmos/utmos.pt",
|
||||||
|
help="path of the UTMOS model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ssl-model-path",
|
||||||
|
type=str,
|
||||||
|
default="model/huggingface/utmos/wav2vec_small.pt",
|
||||||
|
help="path of the wav2vec SSL model",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class UTMOSScore:
|
||||||
|
"""Predicting score for each audio clip."""
|
||||||
|
|
||||||
|
def __init__(self, utmos_model_path, ssl_model_path):
|
||||||
|
self.sample_rate = 16000
|
||||||
|
self.device = (
|
||||||
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
)
|
||||||
|
self.model = (
|
||||||
|
BaselineLightningModule.load_from_checkpoint(
|
||||||
|
utmos_model_path, ssl_model_path=ssl_model_path
|
||||||
|
)
|
||||||
|
.eval()
|
||||||
|
.to(self.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
def score(self, wavs: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
wavs: waveforms to be evaluated. When len(wavs) == 1 or 2,
|
||||||
|
the model processes the input as a single audio clip. The model
|
||||||
|
performs batch processing when len(wavs) == 3.
|
||||||
|
"""
|
||||||
|
if len(wavs.shape) == 1:
|
||||||
|
out_wavs = wavs.unsqueeze(0).unsqueeze(0)
|
||||||
|
elif len(wavs.shape) == 2:
|
||||||
|
out_wavs = wavs.unsqueeze(0)
|
||||||
|
elif len(wavs.shape) == 3:
|
||||||
|
out_wavs = wavs
|
||||||
|
else:
|
||||||
|
raise ValueError("Dimension of input tensor needs to be <= 3.")
|
||||||
|
bs = out_wavs.shape[0]
|
||||||
|
batch = {
|
||||||
|
"wav": out_wavs,
|
||||||
|
"domains": torch.zeros(bs, dtype=torch.int).to(self.device),
|
||||||
|
"judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
|
||||||
|
}
|
||||||
|
with torch.no_grad():
|
||||||
|
output = self.model(batch)
|
||||||
|
|
||||||
|
return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
|
||||||
|
|
||||||
|
def score_dir(self, dir, dtype="float32"):
|
||||||
|
def _load_speech_task(fname, sample_rate):
|
||||||
|
|
||||||
|
wav_data, sr = sf.read(fname, dtype=dtype)
|
||||||
|
if sr != sample_rate:
|
||||||
|
wav_data = librosa.resample(
|
||||||
|
wav_data, orig_sr=sr, target_sr=self.sample_rate
|
||||||
|
)
|
||||||
|
wav_data = torch.from_numpy(wav_data)
|
||||||
|
|
||||||
|
return wav_data
|
||||||
|
|
||||||
|
score_lst = []
|
||||||
|
for fname in tqdm(os.listdir(dir)):
|
||||||
|
speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate)
|
||||||
|
speech = speech.to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
score = self.score(speech)
|
||||||
|
score_lst.append(score.item())
|
||||||
|
return np.mean(score_lst)
|
||||||
|
|
||||||
|
|
||||||
|
def load_ssl_model(ckpt_path="wav2vec_small.pt"):
|
||||||
|
SSL_OUT_DIM = 768
|
||||||
|
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||||
|
[ckpt_path]
|
||||||
|
)
|
||||||
|
ssl_model = model[0]
|
||||||
|
ssl_model.remove_pretraining_modules()
|
||||||
|
return SSL_model(ssl_model, SSL_OUT_DIM)
|
||||||
|
|
||||||
|
|
||||||
|
class BaselineLightningModule(pl.LightningModule):
|
||||||
|
def __init__(self, ssl_model_path):
|
||||||
|
super().__init__()
|
||||||
|
self.construct_model(ssl_model_path)
|
||||||
|
self.save_hyperparameters()
|
||||||
|
|
||||||
|
def construct_model(self, ssl_model_path):
|
||||||
|
self.feature_extractors = nn.ModuleList(
|
||||||
|
[
|
||||||
|
load_ssl_model(ckpt_path=ssl_model_path),
|
||||||
|
DomainEmbedding(3, 128),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
output_dim = sum(
|
||||||
|
[
|
||||||
|
feature_extractor.get_output_dim()
|
||||||
|
for feature_extractor in self.feature_extractors
|
||||||
|
]
|
||||||
|
)
|
||||||
|
output_layers = [
|
||||||
|
LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
|
||||||
|
]
|
||||||
|
output_dim = output_layers[-1].get_output_dim()
|
||||||
|
output_layers.append(
|
||||||
|
Projection(
|
||||||
|
hidden_dim=2048,
|
||||||
|
activation=torch.nn.ReLU(),
|
||||||
|
range_clipping=False,
|
||||||
|
input_dim=output_dim,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.output_layers = nn.ModuleList(output_layers)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
outputs = {}
|
||||||
|
for feature_extractor in self.feature_extractors:
|
||||||
|
outputs.update(feature_extractor(inputs))
|
||||||
|
x = outputs
|
||||||
|
for output_layer in self.output_layers:
|
||||||
|
x = output_layer(x, inputs)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SSL_model(nn.Module):
|
||||||
|
def __init__(self, ssl_model, ssl_out_dim) -> None:
|
||||||
|
super(SSL_model, self).__init__()
|
||||||
|
self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
wav = batch["wav"]
|
||||||
|
wav = wav.squeeze(1) # [batches, wav_len]
|
||||||
|
res = self.ssl_model(wav, mask=False, features_only=True)
|
||||||
|
x = res["x"]
|
||||||
|
return {"ssl-feature": x}
|
||||||
|
|
||||||
|
def get_output_dim(self):
|
||||||
|
return self.ssl_out_dim
|
||||||
|
|
||||||
|
|
||||||
|
class DomainEmbedding(nn.Module):
|
||||||
|
def __init__(self, n_domains, domain_dim) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.embedding = nn.Embedding(n_domains, domain_dim)
|
||||||
|
self.output_dim = domain_dim
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
return {"domain-feature": self.embedding(batch["domains"])}
|
||||||
|
|
||||||
|
def get_output_dim(self):
|
||||||
|
return self.output_dim
|
||||||
|
|
||||||
|
|
||||||
|
class LDConditioner(nn.Module):
|
||||||
|
"""
|
||||||
|
Conditions ssl output by listener embedding
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim, judge_dim, num_judges=None):
|
||||||
|
super().__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.judge_dim = judge_dim
|
||||||
|
self.num_judges = num_judges
|
||||||
|
assert num_judges is not None
|
||||||
|
self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
|
||||||
|
# concat [self.output_layer, phoneme features]
|
||||||
|
|
||||||
|
self.decoder_rnn = nn.LSTM(
|
||||||
|
input_size=self.input_dim + self.judge_dim,
|
||||||
|
hidden_size=512,
|
||||||
|
num_layers=1,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=True,
|
||||||
|
) # linear?
|
||||||
|
self.out_dim = self.decoder_rnn.hidden_size * 2
|
||||||
|
|
||||||
|
def get_output_dim(self):
|
||||||
|
return self.out_dim
|
||||||
|
|
||||||
|
def forward(self, x, batch):
|
||||||
|
judge_ids = batch["judge_id"]
|
||||||
|
if "phoneme-feature" in x.keys():
|
||||||
|
concatenated_feature = torch.cat(
|
||||||
|
(
|
||||||
|
x["ssl-feature"],
|
||||||
|
x["phoneme-feature"]
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(-1, x["ssl-feature"].size(1), -1),
|
||||||
|
),
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
concatenated_feature = x["ssl-feature"]
|
||||||
|
if "domain-feature" in x.keys():
|
||||||
|
concatenated_feature = torch.cat(
|
||||||
|
(
|
||||||
|
concatenated_feature,
|
||||||
|
x["domain-feature"]
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(-1, concatenated_feature.size(1), -1),
|
||||||
|
),
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
if judge_ids is not None:
|
||||||
|
concatenated_feature = torch.cat(
|
||||||
|
(
|
||||||
|
concatenated_feature,
|
||||||
|
self.judge_embedding(judge_ids)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.expand(-1, concatenated_feature.size(1), -1),
|
||||||
|
),
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
|
||||||
|
return decoder_output
|
||||||
|
|
||||||
|
|
||||||
|
class Projection(nn.Module):
|
||||||
|
def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
|
||||||
|
super(Projection, self).__init__()
|
||||||
|
self.range_clipping = range_clipping
|
||||||
|
output_dim = 1
|
||||||
|
if range_clipping:
|
||||||
|
self.proj = nn.Tanh()
|
||||||
|
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, hidden_dim),
|
||||||
|
activation,
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(hidden_dim, output_dim),
|
||||||
|
)
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
def forward(self, x, batch):
|
||||||
|
output = self.net(x)
|
||||||
|
|
||||||
|
# range clipping
|
||||||
|
if self.range_clipping:
|
||||||
|
return self.proj(output) * 2.0 + 3
|
||||||
|
else:
|
||||||
|
return output
|
||||||
|
|
||||||
|
def get_output_dim(self):
|
||||||
|
return self.output_dim
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
UTMOS = UTMOSScore(
|
||||||
|
utmos_model_path=args.utmos_model_path, ssl_model_path=args.ssl_model_path
|
||||||
|
)
|
||||||
|
score = UTMOS.score_dir(args.wav_path)
|
||||||
|
logging.info(f"UTMOS score: {score:.2f}")

egs/zipvoice/local/evaluate_wer_hubert.py

"""
|
||||||
|
Calculate WER with Hubert models.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from jiwer import compute_measures
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="path of the output file of WER information",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="path of the local hubert model, e.g., model/huggingface/hubert-large-ls960-ft",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-list",
|
||||||
|
type=str,
|
||||||
|
default="test.tsv",
|
||||||
|
help="path of the transcript tsv file, where the first column "
|
||||||
|
"is the wav name and the last column is the transcript",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size", type=int, default=16, help="decoding batch size"
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def post_process(text: str):
|
||||||
|
text = text.replace("‘", "'")
|
||||||
|
text = text.replace("’", "'")
|
||||||
|
text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
|
||||||
|
text = re.sub(r"\s+", " ", text)
|
||||||
|
text = text.strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def process_one(hypo, truth):
|
||||||
|
truth = post_process(truth)
|
||||||
|
hypo = post_process(hypo)
|
||||||
|
|
||||||
|
measures = compute_measures(truth, hypo)
|
||||||
|
word_num = len(truth.split(" "))
|
||||||
|
wer = measures["wer"]
|
||||||
|
subs = measures["substitutions"]
|
||||||
|
dele = measures["deletions"]
|
||||||
|
inse = measures["insertions"]
|
||||||
|
return (truth, hypo, wer, subs, dele, inse, word_num)
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechEvalDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, wav_path: str, test_list: str):
|
||||||
|
super().__init__()
|
||||||
|
self.wav_name = []
|
||||||
|
self.wav_paths = []
|
||||||
|
self.transcripts = []
|
||||||
|
with Path(test_list).open("r", encoding="utf8") as f:
|
||||||
|
meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
|
||||||
|
for item in meta:
|
||||||
|
self.wav_name.append(item[0])
|
||||||
|
self.wav_paths.append(Path(wav_path, item[0] + ".wav"))
|
||||||
|
self.transcripts.append(item[-1])
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.wav_paths)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int):
|
||||||
|
wav, sampling_rate = sf.read(self.wav_paths[index])
|
||||||
|
item = {
|
||||||
|
"array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000),
|
||||||
|
"sampling_rate": 16000,
|
||||||
|
"reference": self.transcripts[index],
|
||||||
|
"wav_name": self.wav_name[index],
|
||||||
|
}
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
def main(test_list, wav_path, model_path, decode_path, batch_size, device):
|
||||||
|
|
||||||
|
if model_path is not None:
|
||||||
|
pipe = pipeline(
|
||||||
|
"automatic-speech-recognition",
|
||||||
|
model=model_path,
|
||||||
|
device=device,
|
||||||
|
tokenizer=model_path,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pipe = pipeline(
|
||||||
|
"automatic-speech-recognition",
|
||||||
|
model="facebook/hubert-large-ls960-ft",
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = SpeechEvalDataset(wav_path, test_list)
|
||||||
|
|
||||||
|
bar = tqdm(
|
||||||
|
pipe(
|
||||||
|
dataset,
|
||||||
|
generate_kwargs={"language": "english", "task": "transcribe"},
|
||||||
|
batch_size=batch_size,
|
||||||
|
),
|
||||||
|
total=len(dataset),
|
||||||
|
)
|
||||||
|
|
||||||
|
wers = []
|
||||||
|
inses = []
|
||||||
|
deles = []
|
||||||
|
subses = []
|
||||||
|
word_nums = 0
|
||||||
|
if decode_path:
|
||||||
|
decode_dir = os.path.dirname(decode_path)
|
||||||
|
if not os.path.exists(decode_dir):
|
||||||
|
os.makedirs(decode_dir)
|
||||||
|
fout = open(decode_path, "w")
|
||||||
|
for out in bar:
|
||||||
|
wav_name = out["wav_name"][0]
|
||||||
|
transcription = post_process(out["text"].strip())
|
||||||
|
text_ref = post_process(out["reference"][0].strip())
|
||||||
|
truth, hypo, wer, subs, dele, inse, word_num = process_one(
|
||||||
|
transcription, text_ref
|
||||||
|
)
|
||||||
|
if decode_path:
|
||||||
|
fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
|
||||||
|
wers.append(float(wer))
|
||||||
|
inses.append(float(inse))
|
||||||
|
deles.append(float(dele))
|
||||||
|
subses.append(float(subs))
|
||||||
|
word_nums += word_num
|
||||||
|
|
||||||
|
wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
|
||||||
|
subs = round(np.mean(subses) * 100, 3)
|
||||||
|
dele = round(np.mean(deles) * 100, 3)
|
||||||
|
inse = round(np.mean(inses) * 100, 3)
|
||||||
|
print(f"WER: {wer}%\n")
|
||||||
|
if decode_path:
|
||||||
|
fout.write(f"WER: {wer}%\n")
|
||||||
|
fout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
main(
|
||||||
|
args.test_list,
|
||||||
|
args.wav_path,
|
||||||
|
args.model_path,
|
||||||
|
args.decode_path,
|
||||||
|
args.batch_size,
|
||||||
|
device,
|
||||||
|
)
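
Standalone usage, following the call in local/evaluate.sh; `--decode-path`, if given, stores per-utterance WER details (the output path here is hypothetical):

```bash
python3 local/evaluate_wer_hubert.py \
    --test-list testset/test_librispeech_pc_test_clean.tsv \
    --wav-path results/zipvoice_librispeech_test_clean \
    --decode-path results/zipvoice_librispeech_test_clean/wer.txt
```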

egs/zipvoice/local/evaluate_wer_seedtts.py

"""
|
||||||
|
Calculate WER with Whisper-large-v3 or Paraformer models,
|
||||||
|
following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import string
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import scipy
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
import zhconv
|
||||||
|
from funasr import AutoModel
|
||||||
|
from jiwer import compute_measures
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import WhisperForConditionalGeneration, WhisperProcessor
|
||||||
|
from zhon.hanzi import punctuation
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="path of the output file of WER information",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="path of the local whisper and paraformer model, "
|
||||||
|
"e.g., whisper: model/huggingface/whisper-large-v3/, "
|
||||||
|
"paraformer: model/huggingface/paraformer-zh/",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-list",
|
||||||
|
type=str,
|
||||||
|
default="test.tsv",
|
||||||
|
help="path of the transcript tsv file, where the first column "
|
||||||
|
"is the wav name and the last column is the transcript",
|
||||||
|
)
|
||||||
|
parser.add_argument("--lang", type=str, help="decoded language, zh or en")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def load_en_model(model_path):
|
||||||
|
if model_path is None:
|
||||||
|
model_path = "openai/whisper-large-v3"
|
||||||
|
processor = WhisperProcessor.from_pretrained(model_path)
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained(model_path)
|
||||||
|
return processor, model
|
||||||
|
|
||||||
|
|
||||||
|
def load_zh_model(model_path):
|
||||||
|
if model_path is None:
|
||||||
|
model_path = "paraformer-zh"
|
||||||
|
model = AutoModel(model=model_path)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def process_one(hypo, truth, lang):
|
||||||
|
punctuation_all = punctuation + string.punctuation
|
||||||
|
for x in punctuation_all:
|
||||||
|
if x == "'":
|
||||||
|
continue
|
||||||
|
truth = truth.replace(x, "")
|
||||||
|
hypo = hypo.replace(x, "")
|
||||||
|
|
||||||
|
truth = truth.replace(" ", " ")
|
||||||
|
hypo = hypo.replace(" ", " ")
|
||||||
|
|
||||||
|
if lang == "zh":
|
||||||
|
truth = " ".join([x for x in truth])
|
||||||
|
hypo = " ".join([x for x in hypo])
|
||||||
|
elif lang == "en":
|
||||||
|
truth = truth.lower()
|
||||||
|
hypo = hypo.lower()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
measures = compute_measures(truth, hypo)
|
||||||
|
word_num = len(truth.split(" "))
|
||||||
|
wer = measures["wer"]
|
||||||
|
subs = measures["substitutions"]
|
||||||
|
dele = measures["deletions"]
|
||||||
|
inse = measures["insertions"]
|
||||||
|
return (truth, hypo, wer, subs, dele, inse, word_num)
|
||||||
|
|
||||||
|
|
||||||
|
def main(test_list, wav_path, model_path, decode_path, lang, device):
|
||||||
|
if lang == "en":
|
||||||
|
processor, model = load_en_model(model_path)
|
||||||
|
model.to(device)
|
||||||
|
elif lang == "zh":
|
||||||
|
model = load_zh_model(model_path)
|
||||||
|
params = []
|
||||||
|
for line in open(test_list).readlines():
|
||||||
|
line = line.strip()
|
||||||
|
items = line.split("\t")
|
||||||
|
wav_name, text_ref = items[0], items[-1]
|
||||||
|
file_path = os.path.join(wav_path, wav_name + ".wav")
|
||||||
|
assert os.path.exists(file_path), f"{file_path}"
|
||||||
|
|
||||||
|
params.append((file_path, text_ref))
|
||||||
|
wers = []
|
||||||
|
inses = []
|
||||||
|
deles = []
|
||||||
|
subses = []
|
||||||
|
word_nums = 0
|
||||||
|
if decode_path:
|
||||||
|
decode_dir = os.path.dirname(decode_path)
|
||||||
|
if not os.path.exists(decode_dir):
|
||||||
|
os.makedirs(decode_dir)
|
||||||
|
fout = open(decode_path, "w")
|
||||||
|
for wav_path, text_ref in tqdm(params):
|
||||||
|
if lang == "en":
|
||||||
|
wav, sr = sf.read(wav_path)
|
||||||
|
if sr != 16000:
|
||||||
|
wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
|
||||||
|
input_features = processor(
|
||||||
|
wav, sampling_rate=16000, return_tensors="pt"
|
||||||
|
).input_features
|
||||||
|
input_features = input_features.to(device)
|
||||||
|
forced_decoder_ids = processor.get_decoder_prompt_ids(
|
||||||
|
language="english", task="transcribe"
|
||||||
|
)
|
||||||
|
predicted_ids = model.generate(
|
||||||
|
input_features, forced_decoder_ids=forced_decoder_ids
|
||||||
|
)
|
||||||
|
transcription = processor.batch_decode(
|
||||||
|
predicted_ids, skip_special_tokens=True
|
||||||
|
)[0]
|
||||||
|
elif lang == "zh":
|
||||||
|
res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
|
||||||
|
transcription = res[0]["text"]
|
||||||
|
transcription = zhconv.convert(transcription, "zh-cn")
|
||||||
|
|
||||||
|
truth, hypo, wer, subs, dele, inse, word_num = process_one(
|
||||||
|
transcription, text_ref, lang
|
||||||
|
)
|
||||||
|
if decode_path:
|
||||||
|
fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
|
||||||
|
wers.append(float(wer))
|
||||||
|
inses.append(float(inse))
|
||||||
|
deles.append(float(dele))
|
||||||
|
subses.append(float(subs))
|
||||||
|
word_nums += word_num
|
||||||
|
|
||||||
|
wer_avg = round(np.mean(wers) * 100, 3)
|
||||||
|
wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
|
||||||
|
subs = round(np.mean(subses) * 100, 3)
|
||||||
|
dele = round(np.mean(deles) * 100, 3)
|
||||||
|
inse = round(np.mean(inses) * 100, 3)
|
||||||
|
print(f"Seed-TTS WER: {wer_avg}%\n")
|
||||||
|
print(f"WER: {wer}%\n")
|
||||||
|
if decode_path:
|
||||||
|
fout.write(f"SeedTTS WER: {wer_avg}%\n")
|
||||||
|
fout.write(f"WER: {wer}%\n")
|
||||||
|
fout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
main(
|
||||||
|
args.test_list,
|
||||||
|
args.wav_path,
|
||||||
|
args.model_path,
|
||||||
|
args.decode_path,
|
||||||
|
args.lang,
|
||||||
|
device,
|
||||||
|
)
|
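`process_one()` above delegates the alignment to `jiwer`; a small sketch (with invented strings) shows the quantities it works with for an English pair:

```python
from jiwer import compute_measures

truth = "the quick brown fox"
hypo = "the quick brown box"

measures = compute_measures(truth, hypo)
# One substitution against four reference words -> WER = 0.25
print(measures["wer"], measures["substitutions"], measures["deletions"], measures["insertions"])
```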
135 egs/zipvoice/local/feature.py Normal file
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames


class MelSpectrogramFeatures(nn.Module):
    def __init__(
        self,
        sampling_rate=24000,
        n_mels=100,
        n_fft=1024,
        hop_length=256,
    ):
        super().__init__()

        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=True,
            power=1,
        )

    def forward(self, inp):
        assert len(inp.shape) == 2

        mel = self.mel_spec(inp)
        logmel = mel.clamp(min=1e-7).log()
        return logmel


@dataclass
class TorchAudioFbankConfig:
    sampling_rate: int
    n_mels: int
    n_fft: int
    hop_length: int


@register_extractor
class TorchAudioFbank(FeatureExtractor):

    name = "TorchAudioFbank"
    config_type = TorchAudioFbankConfig

    def __init__(self, config):
        super().__init__(config=config)

    def _feature_fn(self, sample):
        fbank = MelSpectrogramFeatures(
            sampling_rate=self.config.sampling_rate,
            n_mels=self.config.n_mels,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
        )

        return fbank(sample)

    @property
    def device(self) -> Union[str, torch.device]:
        return self.config.device

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.n_mels

    def extract(
        self,
        samples: Union[np.ndarray, torch.Tensor],
        sampling_rate: int,
    ) -> Union[np.ndarray, torch.Tensor]:
        # Check for sampling rate compatibility.
        expected_sr = self.config.sampling_rate
        assert sampling_rate == expected_sr, (
            f"Mismatched sampling rate: extractor expects {expected_sr}, "
            f"got {sampling_rate}"
        )
        is_numpy = False
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
            is_numpy = True

        if len(samples.shape) == 1:
            samples = samples.unsqueeze(0)
        assert samples.ndim == 2, samples.shape
        assert samples.shape[0] == 1, samples.shape

        mel = self._feature_fn(samples).squeeze().t()

        assert mel.ndim == 2, mel.shape
        assert mel.shape[1] == self.config.n_mels, mel.shape

        num_frames = compute_num_frames(
            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
        )

        if mel.shape[0] > num_frames:
            mel = mel[:num_frames]
        elif mel.shape[0] < num_frames:
            mel = mel.unsqueeze(0)
            mel = torch.nn.functional.pad(
                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
            ).squeeze(0)

        if is_numpy:
            return mel.cpu().numpy()
        else:
            return mel

    @property
    def frame_shift(self) -> Seconds:
        return self.config.hop_length / self.config.sampling_rate
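As a rough usage sketch (random noise stands in for real audio; the configuration mirrors the defaults above), the extractor produces a `(num_frames, n_mels)` matrix whose frame shift is `hop_length / sampling_rate` seconds:

```python
import torch

config = TorchAudioFbankConfig(sampling_rate=24000, n_mels=100, n_fft=1024, hop_length=256)
extractor = TorchAudioFbank(config)

wav = torch.randn(1, 24000)  # one second of fake audio at 24 kHz
mel = extractor.extract(wav, sampling_rate=24000)
print(mel.shape)  # (num_frames, 100); frame_shift is 256 / 24000 s
```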
88 egs/zipvoice/local/prepare_libritts.sh Executable file
@@ -0,0 +1,88 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

stage=0
stop_stage=5
sampling_rate=24000
nj=32

dl_dir=$PWD/download

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/LibriTTS,
  # you can create a symlink
  #
  #   ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
  #
  if [ ! -d $dl_dir/LibriTTS ]; then
    lhotse download libritts $dl_dir
  fi

fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare LibriTTS manifest"
  # We assume that you have downloaded the LibriTTS corpus
  # to $dl_dir/LibriTTS
  mkdir -p data/manifests_libritts
  if [ ! -e data/manifests_libritts/.libritts.done ]; then
    lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests_libritts
    touch data/manifests_libritts/.libritts.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Compute Fbank for LibriTTS"
  mkdir -p data/fbank
  if [ ! -e data/fbank_libritts/.libritts.done ]; then
    ./local/compute_fbank_libritts.py --sampling-rate $sampling_rate
    touch data/fbank_libritts/.libritts.done
  fi

  # Here we shuffle and combine the train-clean-100, train-clean-360 and
  # train-other-500 together to form the training set.
  if [ ! -f data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz ]; then
    cat <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-100.jsonl.gz) \
      <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-360.jsonl.gz) \
      <(gunzip -c data/fbank_libritts/libritts_cuts_train-other-500.jsonl.gz) | \
      shuf | gzip -c > data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz
  fi

  if [ ! -f data/fbank_libritts/libritts_cuts_train-clean-460.jsonl.gz ]; then
    cat <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-100.jsonl.gz) \
      <(gunzip -c data/fbank_libritts/libritts_cuts_train-clean-360.jsonl.gz) | \
      shuf | gzip -c > data/fbank_libritts/libritts_cuts_train-clean-460.jsonl.gz
  fi

  if [ ! -e data/fbank_libritts/.libritts-validated.done ]; then
    log "Validating data/fbank for LibriTTS"
    ./local/validate_manifest.py \
      data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz
    touch data/fbank_libritts/.libritts-validated.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Generate token file"
  if [ ! -e data/tokens_libritts.txt ]; then
    ./local/prepare_token_file_libritts.py --tokens data/tokens_libritts.txt
  fi
fi
90 egs/zipvoice/local/prepare_token_file_emilia.py Normal file
@@ -0,0 +1,90 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
#                                        Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file generates the file that maps tokens to IDs.
"""

import argparse
import logging
from pathlib import Path
from typing import List

from piper_phonemize import get_espeak_map
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--tokens",
        type=Path,
        default=Path("data/tokens_emilia.txt"),
        help="Path to the dict that maps the text tokens to IDs",
    )

    parser.add_argument(
        "--pinyin",
        type=Path,
        default=Path("local/pinyin.txt"),
        help="Path to the file with all unique pinyin",
    )

    return parser.parse_args()


def get_pinyin_tokens(pinyin: Path) -> List[str]:
    phones = set()
    with open(pinyin, "r") as f:
        for line in f:
            x = line.strip()
            initial = to_initials(x, strict=False)
            # don't want to share tokens with espeak tokens, so use tone3 style
            finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
            if initial != "":
                # don't want to share tokens with espeak tokens, so add a '0' after each initial
                phones.add(initial + "0")
            if finals != "":
                phones.add(finals)
    return sorted(phones)


def get_token2id(args):
    """Get a dict that maps token to IDs, and save it to the given filename."""
    all_tokens = get_espeak_map()  # token: [token_id]
    all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
    # sort by token_id
    all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])

    all_pinyin = get_pinyin_tokens(args.pinyin)
    with open(args.tokens, "w", encoding="utf-8") as f:
        for token, token_id in all_tokens:
            f.write(f"{token} {token_id}\n")
        num_espeak_tokens = len(all_tokens)
        for i, pinyin in enumerate(all_pinyin):
            f.write(f"{pinyin} {num_espeak_tokens + i}\n")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    get_token2id(args)
31 egs/zipvoice/local/prepare_token_file_libritts.py Normal file
@@ -0,0 +1,31 @@
import re
from collections import Counter

from lhotse import load_manifest_lazy


def prepare_tokens(manifest_file, token_file):
    counter = Counter()
    manifest = load_manifest_lazy(manifest_file)
    for cut in manifest:
        line = re.sub(r"\s+", " ", cut.supervisions[0].text)
        counter.update(line)

    unique_chars = set(counter.keys())

    if "_" in unique_chars:
        unique_chars.remove("_")

    sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True)

    result = ["_"] + sorted_chars

    with open(token_file, "w", encoding="utf-8") as file:
        for index, char in enumerate(result):
            file.write(f"{char} {index}\n")


if __name__ == "__main__":
    manifest_file = "data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz"
    output_token_file = "data/tokens_libritts.txt"
    prepare_tokens(manifest_file, output_token_file)
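The resulting token file is one `token id` pair per line, with `_` reserved as ID 0. A small sketch of reading it back (the path is the default used above):

```python
token2id = {}
with open("data/tokens_libritts.txt", encoding="utf-8") as f:
    for line in f:
        token, idx = line.rstrip("\n").rsplit(" ", 1)
        token2id[token] = int(idx)

print(len(token2id), token2id["_"])  # "_" maps to 0
```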
192 egs/zipvoice/local/prepare_tokens_emilia.py Normal file
@@ -0,0 +1,192 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
#                                        Zengrui Jin,
#                                        Wei Kang)
#           2024 Tsinghua University (authors: Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
"""

import argparse
import glob
import logging
import re
from concurrent.futures import ProcessPoolExecutor as Pool
from pathlib import Path
from typing import List

import jieba
from lhotse import load_manifest_lazy
from tokenizer import Tokenizer, is_alphabet, is_chinese, is_hangul, is_japanese


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--subset",
        type=str,
        help="Subset of emilia, (ZH, EN, etc.)",
    )

    parser.add_argument(
        "--jobs",
        type=int,
        default=50,
        help="Number of jobs for processing.",
    )

    parser.add_argument(
        "--source-dir",
        type=str,
        default="data/manifests_emilia/splits",
        help="The source directory of manifest files.",
    )

    parser.add_argument(
        "--dest-dir",
        type=str,
        help="The destination directory of manifest files.",
    )

    return parser.parse_args()


def tokenize_by_CJK_char(line: str) -> List[str]:
    """
    Tokenize a line of text with CJK char.

    Note: All return characters will be upper case.

    Example:
      input = "你好世界是 hello world 的中文"
      output = [你, 好, 世, 界, 是, HELLO, WORLD, 的, 中, 文]

    Args:
      line:
        The input text.

    Return:
      A new string tokenized by CJK char.
    """
    # The CJK ranges is from https://github.com/alvations/nltk/blob/79eed6ddea0d0a2c212c1060b477fc268fec4d4b/nltk/tokenize/util.py
    pattern = re.compile(
        r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])"
    )
    chars = pattern.split(line.strip().upper())
    char_list = []
    for w in chars:
        if w.strip():
            char_list += w.strip().split()
    return char_list


def prepare_tokens_emilia(file_name: str, input_dir: Path, output_dir: Path):
    logging.info(f"Processing {file_name}")
    if (output_dir / file_name).is_file():
        logging.info(f"{file_name} exists, skipping.")
        return
    jieba.setLogLevel(logging.INFO)
    tokenizer = Tokenizer()

    def _prepare_cut(cut):
        # Each cut only contains one supervision
        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
        text = cut.supervisions[0].text
        cut.supervisions[0].normalized_text = text
        tokens = tokenizer.texts_to_tokens([text])[0]
        cut.tokens = tokens
        return cut

    def _filter_cut(cut):
        text = cut.supervisions[0].text
        duration = cut.supervisions[0].duration
        chinese = []
        english = []

        # only contains chinese and space and alphabets
        clean_chars = []
        for x in text:
            if is_hangul(x):
                logging.info(f"Delete cut with text containing Korean : {text}")
                return False
            if is_japanese(x):
                logging.info(f"Delete cut with text containing Japanese : {text}")
                return False
            if is_chinese(x):
                chinese.append(x)
                clean_chars.append(x)
            if is_alphabet(x):
                english.append(x)
                clean_chars.append(x)
            if x == " ":
                clean_chars.append(x)
        if len(english) + len(chinese) == 0:
            logging.info(f"Delete cut with text has no valid chars : {text}")
            return False

        words = tokenize_by_CJK_char("".join(clean_chars))
        for i in range(len(words) - 10):
            if words[i : i + 10].count(words[i]) == 10:
                logging.info(f"Delete cut with text with too much repeats : {text}")
                return False
        # word speed, 20 - 600 / minute
        if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60:
            logging.info(
                f"Delete cut with audio text mismatch, duration : {duration}s, words : {len(words)}, text : {text}"
            )
            return False
        return True

    try:
        cut_set = load_manifest_lazy(input_dir / file_name)
        cut_set = cut_set.filter(_filter_cut)
        cut_set = cut_set.map(_prepare_cut)
        cut_set.to_file(output_dir / file_name)
    except Exception as e:
        logging.error(f"Manifest {file_name} failed with error: {e}")
        raise


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()

    input_dir = Path(args.source_dir)
    output_dir = Path(args.dest_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")

    with Pool(max_workers=args.jobs) as pool:
        futures = [
            pool.submit(
                prepare_tokens_emilia, filename.split("/")[-1], input_dir, output_dir
            )
            for filename in cut_files
        ]
        for f in futures:
            try:
                f.result()
                f.done()
            except Exception as e:
                logging.error(f"Future failed with error: {e}")
        logging.info("Processing done.")
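The speaking-rate check in `_filter_cut()` keeps cuts whose word rate falls between 20 and 600 words per minute; a small numeric sketch of that condition, with made-up numbers:

```python
num_words = 30
duration = 8.0  # seconds

min_duration = num_words / 600 * 60  # 3.0 s: shorter would exceed 600 words/min
max_duration = num_words / 20 * 60   # 90.0 s: longer would fall below 20 words/min
keep = min_duration <= duration <= max_duration
print(keep)  # True: 30 words in 8 s is about 225 words/min
```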
70 egs/zipvoice/local/validate_manifest.py Executable file
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
#                                            Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:

- Single supervision per cut

We will add more checks later if needed.

Usage example:

  python3 ./local/validate_manifest.py \
    ./data/spectrogram/ljspeech_cuts_all.jsonl.gz

"""

import argparse
import logging
from pathlib import Path

from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "manifest",
        type=Path,
        help="Path to the manifest file",
    )

    return parser.parse_args()


def main():
    args = get_args()

    manifest = args.manifest
    logging.info(f"Validating {manifest}")

    assert manifest.is_file(), f"{manifest} does not exist"
    cut_set = load_manifest_lazy(manifest)
    assert isinstance(cut_set, CutSet), type(cut_set)

    validate_for_tts(cut_set)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
142 egs/zipvoice/zipvoice/checkpoint.py Normal file
@@ -0,0 +1,142 @@
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
#                                                  Zengwei Yao)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer

# use duck typing for LRScheduler since we have different possibilities, see
# our class LRScheduler.
LRSchedulerType = object


def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],
    model_avg: Optional[nn.Module] = None,
    model_ema: Optional[nn.Module] = None,
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. We only save its `state_dict()`.
      model_avg:
        The stored model averaged from the start of training.
      model_ema:
        The EMA version of model.
      params:
        User defined parameters, e.g., epoch, loss.
      optimizer:
        The optimizer to be saved. We only save its `state_dict()`.
      scheduler:
        The scheduler to be saved. We only save its `state_dict()`.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      sampler:
        The sampler used in the labeled training dataset. We only
        save its `state_dict()`.
      rank:
        Used in DDP. We save checkpoint only for the node whose
        rank is 0.
    Returns:
      Return None.
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
        "sampler": sampler.state_dict() if sampler is not None else None,
    }

    if model_avg is not None:
        checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
    if model_ema is not None:
        checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict()

    if params:
        for k, v in params.items():
            assert k not in checkpoint
            checkpoint[k] = v

    torch.save(checkpoint, filename)


def load_checkpoint(
    filename: Path,
    model: Optional[nn.Module] = None,
    model_avg: Optional[nn.Module] = None,
    model_ema: Optional[nn.Module] = None,
    strict: bool = False,
) -> Dict[str, Any]:
    logging.info(f"Loading checkpoint from {filename}")
    checkpoint = torch.load(filename, map_location="cpu")

    if model is not None:

        if next(iter(checkpoint["model"])).startswith("module."):
            logging.info("Loading checkpoint saved by DDP")

            dst_state_dict = model.state_dict()
            src_state_dict = checkpoint["model"]
            for key in dst_state_dict.keys():
                src_key = "{}.{}".format("module", key)
                dst_state_dict[key] = src_state_dict.pop(src_key)
            assert len(src_state_dict) == 0
            model.load_state_dict(dst_state_dict, strict=strict)
        else:
            logging.info("Loading checkpoint")
            model.load_state_dict(checkpoint["model"], strict=strict)

        checkpoint.pop("model")

    if model_avg is not None and "model_avg" in checkpoint:
        logging.info("Loading averaged model")
        model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
        checkpoint.pop("model_avg")

    if model_ema is not None and "model_ema" in checkpoint:
        logging.info("Loading ema model")
        model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
        checkpoint.pop("model_ema")

    return checkpoint
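A minimal round-trip sketch with the two helpers above, using a toy module and an arbitrary file name (not part of the recipe); user-defined `params` come back in the returned dict:

```python
import torch.nn as nn

model = nn.Linear(4, 2)
save_checkpoint(
    filename="toy-checkpoint.pt",
    model=model,
    params={"epoch": 1, "loss": 0.5},
)

restored = nn.Linear(4, 2)
extra = load_checkpoint("toy-checkpoint.pt", model=restored, strict=True)
print(extra["epoch"], extra["loss"])  # 1 0.5
```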
135 egs/zipvoice/zipvoice/feature.py Normal file
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames


class MelSpectrogramFeatures(nn.Module):
    def __init__(
        self,
        sampling_rate=24000,
        n_mels=100,
        n_fft=1024,
        hop_length=256,
    ):
        super().__init__()

        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=True,
            power=1,
        )

    def forward(self, inp):
        assert len(inp.shape) == 2

        mel = self.mel_spec(inp)
        logmel = mel.clamp(min=1e-7).log()
        return logmel


@dataclass
class TorchAudioFbankConfig:
    sampling_rate: int
    n_mels: int
    n_fft: int
    hop_length: int


@register_extractor
class TorchAudioFbank(FeatureExtractor):

    name = "TorchAudioFbank"
    config_type = TorchAudioFbankConfig

    def __init__(self, config):
        super().__init__(config=config)

    def _feature_fn(self, sample):
        fbank = MelSpectrogramFeatures(
            sampling_rate=self.config.sampling_rate,
            n_mels=self.config.n_mels,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
        )

        return fbank(sample)

    @property
    def device(self) -> Union[str, torch.device]:
        return self.config.device

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.n_mels

    def extract(
        self,
        samples: Union[np.ndarray, torch.Tensor],
        sampling_rate: int,
    ) -> Union[np.ndarray, torch.Tensor]:
        # Check for sampling rate compatibility.
        expected_sr = self.config.sampling_rate
        assert sampling_rate == expected_sr, (
            f"Mismatched sampling rate: extractor expects {expected_sr}, "
            f"got {sampling_rate}"
        )
        is_numpy = False
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
            is_numpy = True

        if len(samples.shape) == 1:
            samples = samples.unsqueeze(0)
        assert samples.ndim == 2, samples.shape
        assert samples.shape[0] == 1, samples.shape

        mel = self._feature_fn(samples).squeeze().t()

        assert mel.ndim == 2, mel.shape
        assert mel.shape[1] == self.config.n_mels, mel.shape

        num_frames = compute_num_frames(
            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
        )

        if mel.shape[0] > num_frames:
            mel = mel[:num_frames]
        elif mel.shape[0] < num_frames:
            mel = mel.unsqueeze(0)
            mel = torch.nn.functional.pad(
                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
            ).squeeze(0)

        if is_numpy:
            return mel.cpu().numpy()
        else:
            return mel

    @property
    def frame_shift(self) -> Seconds:
        return self.config.hop_length / self.config.sampling_rate
209 egs/zipvoice/zipvoice/generate_averaged_model.py Executable file
@@ -0,0 +1,209 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
This script loads checkpoints and averages them.

(1) Average ZipVoice models before distill:
    python3 ./zipvoice/generate_averaged_model.py \
        --epoch 11 \
        --avg 4 \
        --distill 0 \
        --token-file data/tokens_emilia.txt \
        --exp-dir ./zipvoice/exp_zipvoice

    It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
    You can later load it by `torch.load("epoch-11-avg-4.pt")`.

(2) Average ZipVoice-Distill models (the first stage model):

    python3 ./zipvoice/generate_averaged_model.py \
        --iter 60000 \
        --avg 7 \
        --distill 1 \
        --token-file data/tokens_emilia.txt \
        --exp-dir ./zipvoice/exp_zipvoice_distill_1stage
"""

import argparse
from pathlib import Path

import torch
from model import get_distill_model, get_model
from tokenizer import TokenizerEmilia, TokenizerLibriTTS
from train_flow import add_model_arguments, get_params

from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
from icefall.utils import str2bool


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=11,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=4,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' or '--iter'",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipvoice/exp_zipvoice",
        help="The experiment dir",
    )

    parser.add_argument(
        "--distill",
        type=str2bool,
        default=False,
        help="Whether to use distill model. ",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        default="emilia",
        choices=["emilia", "libritts"],
        help="The dataset that the model to average was trained on",
    )

    add_model_arguments(parser)

    return parser


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    if params.dataset == "emilia":
        tokenizer = TokenizerEmilia(
            token_file=params.token_file, token_type=params.token_type
        )
    elif params.dataset == "libritts":
        tokenizer = TokenizerLibriTTS(
            token_file=params.token_file, token_type=params.token_type
        )

    params.vocab_size = tokenizer.vocab_size
    params.pad_id = tokenizer.pad_id

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    print("Script started")

    params.device = torch.device("cpu")
    print(f"Device: {params.device}")

    print("About to create model")
    if params.distill:
        model = get_distill_model(params)
    else:
        model = get_model(params)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg + 1
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg + 1:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        filename_start = filenames[-1]
        filename_end = filenames[0]
        print(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
        model.to(params.device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=params.device,
            ),
            strict=True,
        )
    else:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        print(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )
        model.to(params.device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=params.device,
            ),
            strict=True,
        )
    if params.iter > 0:
        filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
    else:
        filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
    torch.save({"model": model.state_dict()}, filename)

    num_param = sum([p.numel() for p in model.parameters()])
    print(f"Number of model parameters: {num_param}")

    print("Done!")


if __name__ == "__main__":
    main()
586 egs/zipvoice/zipvoice/infer.py Normal file
@@ -0,0 +1,586 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
|
||||||
|
# Han Zhu)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script loads checkpoints to generate waveforms.
|
||||||
|
This script is supposed to be used with the model trained by yourself.
|
||||||
|
If you want to use the pre-trained checkpoints provided by us, please refer to zipvoice_infer.py.
|
||||||
|
|
||||||
|
(1) Usage with a pre-trained checkpoint:
|
||||||
|
|
||||||
|
(a) ZipVoice model before distill:
|
||||||
|
python3 zipvoice/infer.py \
|
||||||
|
--checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
|
||||||
|
--distill 0 \
|
||||||
|
--token-file "data/tokens_emilia.txt" \
|
||||||
|
--test-list test.tsv \
|
||||||
|
--res-dir results/test \
|
||||||
|
--num-step 16 \
|
||||||
|
--guidance-scale 1
|
||||||
|
|
||||||
|
(b) ZipVoice-Distill:
|
||||||
|
python3 zipvoice/infer.py \
|
||||||
|
--checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
|
||||||
|
--distill 1 \
|
||||||
|
--token-file "data/tokens_emilia.txt" \
|
||||||
|
--test-list test.tsv \
|
||||||
|
--res-dir results/test_distill \
|
||||||
|
--num-step 8 \
|
||||||
|
--guidance-scale 3
|
||||||
|
|
||||||
|
(2) Usage with a directory of checkpoints (may requires checkpoint averaging):
|
||||||
|
|
||||||
|
(a) ZipVoice model before distill:
|
||||||
|
python3 flow_match/infer.py \
|
||||||
|
--exp-dir zipvoice/exp_zipvoice \
|
||||||
|
--epoch 11 \
|
||||||
|
--avg 4 \
|
||||||
|
--distill 0 \
|
||||||
|
--token-file "data/tokens_emilia.txt" \
|
||||||
|
--test-list test.tsv \
|
||||||
|
--res-dir results \
|
||||||
|
--num-step 16 \
|
||||||
|
--guidance-scale 1
|
||||||
|
|
||||||
|
(b) ZipVoice-Distill:
|
||||||
|
python3 flow_match/infer.py \
|
||||||
|
--exp-dir zipvoice/exp_zipvoice_distill/ \
|
||||||
|
--iter 2000 \
|
||||||
|
--avg 0 \
|
||||||
|
--distill 1 \
|
||||||
|
--token-file "data/tokens_emilia.txt" \
|
||||||
|
--test-list test.tsv \
|
||||||
|
--res-dir results \
|
||||||
|
--num-step 8 \
|
||||||
|
--guidance-scale 3
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime as dt
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchaudio
|
||||||
|
from checkpoint import load_checkpoint
|
||||||
|
from feature import TorchAudioFbank, TorchAudioFbankConfig
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from model import get_distill_model, get_model
|
||||||
|
from tokenizer import TokenizerEmilia, TokenizerLibriTTS
|
||||||
|
from train_flow import add_model_arguments, get_params
|
||||||
|
from vocos import Vocos
|
||||||
|
|
||||||
|
from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
|
||||||
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The checkpoint for inference. "
|
||||||
|
"If it is None, will use checkpoints under exp_dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipvoice/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' or '--iter', avg=0 means no avg",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocoder-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The local vocos vocoder path, downloaded from huggingface, "
|
||||||
|
"will download the vocodoer from huggingface if it is None.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--distill",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether it is a distilled TTS model.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-list",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The list of prompt speech, prompt_transcription, "
|
||||||
|
"and text to synthesize in the format of "
|
||||||
|
"'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--res-dir",
|
||||||
|
type=str,
|
||||||
|
default="results",
|
||||||
|
help="Path name of the generated wavs dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset",
|
||||||
|
type=str,
|
||||||
|
default="emilia",
|
||||||
|
choices=["emilia", "libritts"],
|
||||||
|
help="The used training dataset for the model to inference",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--guidance-scale",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="The scale of classifier-free guidance during inference.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-step",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="The number of sampling steps.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--feat-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="The scale factor of fbank feature",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--speed",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--t-shift",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="Shift t to smaller ones if t_shift < 1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--target-rms",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="Target speech normalization rms value",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=666,
|
||||||
|
help="Random seed",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_vocoder(vocos_local_path: Optional[str] = None):
|
||||||
|
if vocos_local_path:
|
||||||
|
vocos_local_path = "model/huggingface/vocos-mel-24khz/"
|
||||||
|
vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
||||||
|
state_dict = torch.load(
|
||||||
|
f"{vocos_local_path}/pytorch_model.bin",
|
||||||
|
weights_only=True,
|
||||||
|
map_location="cpu",
|
||||||
|
)
|
||||||
|
vocoder.load_state_dict(state_dict)
|
||||||
|
else:
|
||||||
|
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||||
|
return vocoder
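# A usage sketch (the local directory name below is an assumption, not a
# fixed path): pass a directory that contains the vocos config.yaml and
# pytorch_model.bin, e.g.
#   vocoder = get_vocoder("download/vocos-mel-24khz")
# or pass None to download "charactr/vocos-mel-24khz" from HuggingFace.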
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sentence(
|
||||||
|
save_path: str,
|
||||||
|
prompt_text: str,
|
||||||
|
prompt_wav: str,
|
||||||
|
text: str,
|
||||||
|
model: nn.Module,
|
||||||
|
vocoder: nn.Module,
|
||||||
|
tokenizer: TokenizerEmilia,
|
||||||
|
feature_extractor: TorchAudioFbank,
|
||||||
|
device: torch.device,
|
||||||
|
num_step: int = 16,
|
||||||
|
guidance_scale: float = 1.0,
|
||||||
|
speed: float = 1.0,
|
||||||
|
t_shift: float = 0.5,
|
||||||
|
target_rms: float = 0.1,
|
||||||
|
feat_scale: float = 0.1,
|
||||||
|
sampling_rate: int = 24000,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate the waveform for a text, given a prompt waveform
and its transcription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_path (str): Path to save the generated wav.
|
||||||
|
prompt_text (str): Transcription of the prompt wav.
|
||||||
|
prompt_wav (str): Path to the prompt wav file.
|
||||||
|
text (str): Text to be synthesized into a waveform.
|
||||||
|
model (nn.Module): The model used for generation.
|
||||||
|
vocoder (nn.Module): The vocoder used to convert features to waveforms.
|
||||||
|
tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
|
||||||
|
feature_extractor (TorchAudioFbank): The feature extractor used to
|
||||||
|
extract acoustic features.
|
||||||
|
device (torch.device): The device on which computations are performed.
|
||||||
|
num_step (int, optional): Number of steps for decoding. Defaults to 16.
|
||||||
|
guidance_scale (float, optional): Scale for classifier-free guidance.
|
||||||
|
Defaults to 1.0.
|
||||||
|
speed (float, optional): Speed control. Defaults to 1.0.
|
||||||
|
t_shift (float, optional): Time shift. Defaults to 0.5.
|
||||||
|
target_rms (float, optional): Target RMS for waveform normalization.
|
||||||
|
Defaults to 0.1.
|
||||||
|
feat_scale (float, optional): Scale for features.
|
||||||
|
Defaults to 0.1.
|
||||||
|
sampling_rate (int, optional): Sampling rate for the waveform.
|
||||||
|
Defaults to 24000.
|
||||||
|
Returns:
|
||||||
|
metrics (dict): Dictionary containing time and real-time
|
||||||
|
factor metrics for processing.
|
||||||
|
"""
|
||||||
|
# Convert text to tokens
|
||||||
|
tokens = tokenizer.texts_to_token_ids([text])
|
||||||
|
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
|
||||||
|
|
||||||
|
# Load and preprocess prompt wav
|
||||||
|
prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
|
||||||
|
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
|
||||||
|
if prompt_rms < target_rms:
|
||||||
|
prompt_wav = prompt_wav * target_rms / prompt_rms
|
||||||
|
|
||||||
|
if prompt_sampling_rate != sampling_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(
|
||||||
|
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
|
||||||
|
)
|
||||||
|
prompt_wav = resampler(prompt_wav)
|
||||||
|
|
||||||
|
# Extract features from prompt wav
|
||||||
|
prompt_features = feature_extractor.extract(
|
||||||
|
prompt_wav, sampling_rate=sampling_rate
|
||||||
|
).to(device)
|
||||||
|
prompt_features = prompt_features.unsqueeze(0) * feat_scale
|
||||||
|
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
|
||||||
|
|
||||||
|
# Start timing
|
||||||
|
start_t = dt.datetime.now()
|
||||||
|
|
||||||
|
# Generate features
|
||||||
|
(
|
||||||
|
pred_features,
|
||||||
|
pred_features_lens,
|
||||||
|
pred_prompt_features,
|
||||||
|
pred_prompt_features_lens,
|
||||||
|
) = model.sample(
|
||||||
|
tokens=tokens,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
prompt_features=prompt_features,
|
||||||
|
prompt_features_lens=prompt_features_lens,
|
||||||
|
speed=speed,
|
||||||
|
t_shift=t_shift,
|
||||||
|
duration="predict",
|
||||||
|
num_step=num_step,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Postprocess predicted features
|
||||||
|
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
|
||||||
|
|
||||||
|
# Start vocoder processing
|
||||||
|
start_vocoder_t = dt.datetime.now()
|
||||||
|
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
|
||||||
|
|
||||||
|
# Calculate processing times and real-time factors
|
||||||
|
t = (dt.datetime.now() - start_t).total_seconds()
|
||||||
|
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
|
||||||
|
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
|
||||||
|
wav_seconds = wav.shape[-1] / sampling_rate
|
||||||
|
rtf = t / wav_seconds
|
||||||
|
rtf_no_vocoder = t_no_vocoder / wav_seconds
|
||||||
|
rtf_vocoder = t_vocoder / wav_seconds
|
||||||
|
metrics = {
|
||||||
|
"t": t,
|
||||||
|
"t_no_vocoder": t_no_vocoder,
|
||||||
|
"t_vocoder": t_vocoder,
|
||||||
|
"wav_seconds": wav_seconds,
|
||||||
|
"rtf": rtf,
|
||||||
|
"rtf_no_vocoder": rtf_no_vocoder,
|
||||||
|
"rtf_vocoder": rtf_vocoder,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Adjust wav volume if necessary
|
||||||
|
if prompt_rms < target_rms:
|
||||||
|
wav = wav * prompt_rms / target_rms
|
||||||
|
wav = wav[0].cpu().numpy()
|
||||||
|
sf.write(save_path, wav, sampling_rate)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
vocoder: nn.Module,
|
||||||
|
tokenizer: TokenizerEmilia,
|
||||||
|
):
|
||||||
|
total_t = []
|
||||||
|
total_t_no_vocoder = []
|
||||||
|
total_t_vocoder = []
|
||||||
|
total_wav_seconds = []
|
||||||
|
|
||||||
|
config = TorchAudioFbankConfig(
|
||||||
|
sampling_rate=params.sampling_rate,
|
||||||
|
n_mels=100,
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
)
|
||||||
|
feature_extractor = TorchAudioFbank(config)
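# Note: this 100-bin mel / 1024-point FFT / 256 hop configuration at
# params.sampling_rate is assumed to match the mel settings expected by the
# vocos-mel-24khz vocoder; changing it would require a matching vocoder.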
|
||||||
|
|
||||||
|
with open(params.test_list, "r") as fr:
|
||||||
|
lines = fr.readlines()
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
|
||||||
|
save_path = f"{params.wav_dir}/{wav_name}.wav"
|
||||||
|
metrics = generate_sentence(
|
||||||
|
save_path=save_path,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
prompt_wav=prompt_wav,
|
||||||
|
text=text,
|
||||||
|
model=model,
|
||||||
|
vocoder=vocoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
device=params.device,
|
||||||
|
num_step=params.num_step,
|
||||||
|
guidance_scale=params.guidance_scale,
|
||||||
|
speed=params.speed,
|
||||||
|
t_shift=params.t_shift,
|
||||||
|
target_rms=params.target_rms,
|
||||||
|
feat_scale=params.feat_scale,
|
||||||
|
sampling_rate=params.sampling_rate,
|
||||||
|
)
|
||||||
|
print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
|
||||||
|
total_t.append(metrics["t"])
|
||||||
|
total_t_no_vocoder.append(metrics["t_no_vocoder"])
|
||||||
|
total_t_vocoder.append(metrics["t_vocoder"])
|
||||||
|
total_wav_seconds.append(metrics["wav_seconds"])
|
||||||
|
|
||||||
|
print(f"Average RTF: " f"{np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
|
||||||
|
print(
|
||||||
|
f"Average RTF w/o vocoder: "
|
||||||
|
f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Average RTF vocoder: "
|
||||||
|
f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = (
|
||||||
|
f"wavs-iter-{params.iter}-avg"
|
||||||
|
f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
|
||||||
|
)
|
||||||
|
elif params.epoch > 0:
|
||||||
|
params.suffix = (
|
||||||
|
f"wavs-epoch-{params.epoch}-avg"
|
||||||
|
f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
params.suffix = "wavs"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
params.device = torch.device("cuda", 0)
|
||||||
|
else:
|
||||||
|
params.device = torch.device("cpu")
|
||||||
|
|
||||||
|
logging.info(f"Device: {params.device}")
|
||||||
|
|
||||||
|
if params.dataset == "emilia":
|
||||||
|
tokenizer = TokenizerEmilia(
|
||||||
|
token_file=params.token_file, token_type=params.token_type
|
||||||
|
)
|
||||||
|
elif params.dataset == "libritts":
|
||||||
|
tokenizer = TokenizerLibriTTS(
|
||||||
|
token_file=params.token_file, token_type=params.token_type
|
||||||
|
)
|
||||||
|
|
||||||
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
params.pad_id = tokenizer.pad_id
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
fix_random_seed(params.seed)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
if params.distill:
|
||||||
|
model = get_distill_model(params)
|
||||||
|
else:
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
if params.checkpoint:
|
||||||
|
load_checkpoint(params.checkpoint, model, strict=True)
|
||||||
|
else:
|
||||||
|
if params.avg == 0:
|
||||||
|
if params.iter > 0:
|
||||||
|
load_checkpoint(
|
||||||
|
f"{params.exp_dir}/checkpoint-{params.iter}.pt", model, strict=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
load_checkpoint(
|
||||||
|
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(params.device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=params.device,
|
||||||
|
),
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(params.device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=params.device,
|
||||||
|
),
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = model.to(params.device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
vocoder = get_vocoder(params.vocoder_path)
|
||||||
|
vocoder = vocoder.to(params.device)
|
||||||
|
vocoder.eval()
|
||||||
|
num_param = sum([p.numel() for p in vocoder.parameters()])
|
||||||
|
logging.info(f"Number of vocoder parameters: {num_param}")
|
||||||
|
|
||||||
|
params.wav_dir = f"{params.res_dir}/{params.suffix}"
|
||||||
|
os.makedirs(params.wav_dir, exist_ok=True)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
params.test_list is not None
|
||||||
|
), "Please provide --test-list for speech synthesize."
|
||||||
|
generate(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
vocoder=vocoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
main()
|
578
egs/zipvoice/zipvoice/model.py
Normal file
@ -0,0 +1,578 @@
|
|||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
|
||||||
|
# Han Zhu)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from scaling import ScheduledFloat
|
||||||
|
from solver import EulerSolver
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from utils import (
|
||||||
|
AttributeDict,
|
||||||
|
condition_time_mask,
|
||||||
|
get_tokens_index,
|
||||||
|
make_pad_mask,
|
||||||
|
pad_labels,
|
||||||
|
prepare_avg_tokens_durations,
|
||||||
|
to_int_tuple,
|
||||||
|
)
|
||||||
|
from zipformer import TTSZipformer
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(params: AttributeDict) -> nn.Module:
|
||||||
|
"""Get the normal TTS model."""
|
||||||
|
|
||||||
|
fm_decoder = get_fm_decoder_model(params)
|
||||||
|
text_encoder = get_text_encoder_model(params)
|
||||||
|
|
||||||
|
model = TtsModel(
|
||||||
|
fm_decoder=fm_decoder,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
text_embed_dim=params.text_embed_dim,
|
||||||
|
feat_dim=params.feat_dim,
|
||||||
|
vocab_size=params.vocab_size,
|
||||||
|
pad_id=params.pad_id,
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_distill_model(params: AttributeDict) -> nn.Module:
|
||||||
|
"""Get the distillation TTS model."""
|
||||||
|
|
||||||
|
fm_decoder = get_fm_decoder_model(params, distill=True)
|
||||||
|
text_encoder = get_text_encoder_model(params)
|
||||||
|
|
||||||
|
model = DistillTTSModelTrainWrapper(
|
||||||
|
fm_decoder=fm_decoder,
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
text_embed_dim=params.text_embed_dim,
|
||||||
|
feat_dim=params.feat_dim,
|
||||||
|
vocab_size=params.vocab_size,
|
||||||
|
pad_id=params.pad_id,
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def get_fm_decoder_model(params: AttributeDict, distill: bool = False) -> nn.Module:
|
||||||
|
"""Get the Zipformer-based FM decoder model."""
|
||||||
|
|
||||||
|
encoder = TTSZipformer(
|
||||||
|
in_dim=params.feat_dim * 3,
|
||||||
|
out_dim=params.feat_dim,
|
||||||
|
downsampling_factor=to_int_tuple(params.fm_decoder_downsampling_factor),
|
||||||
|
num_encoder_layers=to_int_tuple(params.fm_decoder_num_layers),
|
||||||
|
cnn_module_kernel=to_int_tuple(params.fm_decoder_cnn_module_kernel),
|
||||||
|
encoder_dim=params.fm_decoder_dim,
|
||||||
|
feedforward_dim=params.fm_decoder_feedforward_dim,
|
||||||
|
num_heads=params.fm_decoder_num_heads,
|
||||||
|
query_head_dim=params.query_head_dim,
|
||||||
|
pos_head_dim=params.pos_head_dim,
|
||||||
|
value_head_dim=params.value_head_dim,
|
||||||
|
pos_dim=params.pos_dim,
|
||||||
|
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||||
|
warmup_batches=4000.0,
|
||||||
|
use_time_embed=True,
|
||||||
|
time_embed_dim=192,
|
||||||
|
use_guidance_scale_embed=distill,
|
||||||
|
)
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def get_text_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
|
"""Get the Zipformer-based text encoder model."""
|
||||||
|
|
||||||
|
encoder = TTSZipformer(
|
||||||
|
in_dim=params.text_embed_dim,
|
||||||
|
out_dim=params.feat_dim,
|
||||||
|
downsampling_factor=to_int_tuple(params.text_encoder_downsampling_factor),
|
||||||
|
num_encoder_layers=to_int_tuple(params.text_encoder_num_layers),
|
||||||
|
cnn_module_kernel=to_int_tuple(params.text_encoder_cnn_module_kernel),
|
||||||
|
encoder_dim=params.text_encoder_dim,
|
||||||
|
feedforward_dim=params.text_encoder_feedforward_dim,
|
||||||
|
num_heads=params.text_encoder_num_heads,
|
||||||
|
query_head_dim=params.query_head_dim,
|
||||||
|
pos_head_dim=params.pos_head_dim,
|
||||||
|
value_head_dim=params.value_head_dim,
|
||||||
|
pos_dim=params.pos_dim,
|
||||||
|
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||||
|
warmup_batches=4000.0,
|
||||||
|
use_time_embed=False,
|
||||||
|
)
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
class TtsModel(nn.Module):
|
||||||
|
"""The normal TTS model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fm_decoder: nn.Module,
|
||||||
|
text_encoder: nn.Module,
|
||||||
|
text_embed_dim: int,
|
||||||
|
feat_dim: int,
|
||||||
|
vocab_size: int,
|
||||||
|
pad_id: int = 0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
fm_decoder: the flow-matching decoder model; its inputs are the
condition embeddings and noisy acoustic features, and its
outputs are the predicted velocities of the acoustic features.
|
||||||
|
text_encoder: the text encoder model; its inputs are text
embeddings and its outputs are contextualized text embeddings.
|
||||||
|
text_embed_dim: dimension of text embedding.
|
||||||
|
feat_dim: dimension of acoustic features.
|
||||||
|
vocab_size: vocabulary size.
|
||||||
|
pad_id: padding id.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.feat_dim = feat_dim
|
||||||
|
self.text_embed_dim = text_embed_dim
|
||||||
|
self.pad_id = pad_id
|
||||||
|
|
||||||
|
self.fm_decoder = fm_decoder
|
||||||
|
|
||||||
|
self.text_encoder = text_encoder
|
||||||
|
|
||||||
|
self.embed = nn.Embedding(vocab_size, text_embed_dim)
|
||||||
|
|
||||||
|
self.distill = False
|
||||||
|
|
||||||
|
def forward_fm_decoder(
|
||||||
|
self,
|
||||||
|
t: torch.Tensor,
|
||||||
|
xt: torch.Tensor,
|
||||||
|
text_condition: torch.Tensor,
|
||||||
|
speech_condition: torch.Tensor,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
guidance_scale: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute velocity.
|
||||||
|
Args:
|
||||||
|
t: A tensor of shape (N, 1, 1) or a tensor of a float,
|
||||||
|
in the range of (0, 1).
|
||||||
|
xt: the input of the current timestep, including condition
|
||||||
|
embeddings and noisy acoustic features.
|
||||||
|
text_condition: the text condition embeddings, with the
|
||||||
|
shape (batch, seq_len, emb_dim).
|
||||||
|
speech_condition: the speech condition embeddings, with the
|
||||||
|
shape (batch, seq_len, emb_dim).
|
||||||
|
padding_mask: The mask for padding, True means masked
|
||||||
|
position, with the shape (N, T).
|
||||||
|
guidance_scale: The guidance scale in classifier-free guidance,
|
||||||
|
which is a tensor of shape (N, 1, 1) or a tensor of a float.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
predicted velocity, with the shape (batch, seq_len, emb_dim).
|
||||||
|
"""
|
||||||
|
assert t.dim() in (0, 3)
|
||||||
|
# Handle t with the shape (N, 1, 1):
|
||||||
|
# squeeze the last dimension if its size is 1.
|
||||||
|
while t.dim() > 1 and t.size(-1) == 1:
|
||||||
|
t = t.squeeze(-1)
|
||||||
|
if guidance_scale is not None:
|
||||||
|
while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1:
|
||||||
|
guidance_scale = guidance_scale.squeeze(-1)
|
||||||
|
# Handle t with a single value: expand to the size of batch size.
|
||||||
|
if t.dim() == 0:
|
||||||
|
t = t.repeat(xt.shape[0])
|
||||||
|
if guidance_scale is not None and guidance_scale.dim() == 0:
|
||||||
|
guidance_scale = guidance_scale.repeat(xt.shape[0])
|
||||||
|
|
||||||
|
xt = torch.cat([xt, text_condition, speech_condition], dim=2)
|
||||||
|
vt = self.fm_decoder(
|
||||||
|
x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale
|
||||||
|
)
|
||||||
|
return vt
|
||||||
|
|
||||||
|
def forward_text_embed(
|
||||||
|
self,
|
||||||
|
tokens: List[List[int]],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the text embeddings.
|
||||||
|
Args:
|
||||||
|
tokens: a list of list of token ids.
|
||||||
|
Returns:
|
||||||
|
embed: the text embeddings, shape (batch, seq_len, emb_dim).
|
||||||
|
tokens_lens: the length of each token sequence, shape (batch,).
|
||||||
|
"""
|
||||||
|
device = (
|
||||||
|
self.device if isinstance(self, DDP) else next(self.parameters()).device
|
||||||
|
)
|
||||||
|
tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
|
||||||
|
embed = self.embed(tokens_padded) # (B, S, C)
|
||||||
|
tokens_lens = torch.tensor(
|
||||||
|
[len(token) for token in tokens], dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
|
||||||
|
|
||||||
|
embed = self.text_encoder(
|
||||||
|
x=embed, t=None, padding_mask=tokens_padding_mask
|
||||||
|
) # (B, S, C)
|
||||||
|
return embed, tokens_lens
|
||||||
|
|
||||||
|
def forward_text_condition(
|
||||||
|
self,
|
||||||
|
embed: torch.Tensor,
|
||||||
|
tokens_lens: torch.Tensor,
|
||||||
|
features_lens: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the text condition with the same length of the acoustic feature.
|
||||||
|
Args:
|
||||||
|
embed: the text embeddings, shape (batch, token_seq_len, emb_dim).
|
||||||
|
tokens_lens: the length of each token sequence, shape (batch,).
|
||||||
|
features_lens: the length of each acoustic feature sequence,
|
||||||
|
shape (batch,).
|
||||||
|
Returns:
|
||||||
|
text_condition: the text condition, shape
|
||||||
|
(batch, feature_seq_len, emb_dim).
|
||||||
|
padding_mask: the padding mask of text condition, shape
|
||||||
|
(batch, feature_seq_len).
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_frames = int(features_lens.max())
|
||||||
|
|
||||||
|
padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
|
||||||
|
|
||||||
|
tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens)
|
||||||
|
|
||||||
|
tokens_index = get_tokens_index(tokens_durations, num_frames).to(
|
||||||
|
embed.device
|
||||||
|
) # (B, T)
|
||||||
|
|
||||||
|
text_condition = torch.gather(
|
||||||
|
embed,
|
||||||
|
dim=1,
|
||||||
|
index=tokens_index.unsqueeze(-1).expand(
|
||||||
|
embed.size(0), num_frames, embed.size(-1)
|
||||||
|
),
|
||||||
|
) # (B, T, F)
|
||||||
|
return text_condition, padding_mask
|
||||||
|
|
||||||
|
def forward_text_train(
|
||||||
|
self,
|
||||||
|
tokens: List[List[int]],
|
||||||
|
features_lens: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Process text for training, given text tokens and real feature lengths.
|
||||||
|
"""
|
||||||
|
embed, tokens_lens = self.forward_text_embed(tokens)
|
||||||
|
text_condition, padding_mask = self.forward_text_condition(
|
||||||
|
embed, tokens_lens, features_lens
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
text_condition,
|
||||||
|
padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_text_inference_gt_duration(
|
||||||
|
self,
|
||||||
|
tokens: List[List[int]],
|
||||||
|
features_lens: torch.Tensor,
|
||||||
|
prompt_tokens: List[List[int]],
|
||||||
|
prompt_features_lens: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Process text for inference, given text tokens, real feature lengths and prompts.
|
||||||
|
"""
|
||||||
|
tokens = [
|
||||||
|
prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
|
||||||
|
]
|
||||||
|
features_lens = prompt_features_lens + features_lens
|
||||||
|
embed, tokens_lens = self.forward_text_embed(tokens)
|
||||||
|
text_condition, padding_mask = self.forward_text_condition(
|
||||||
|
embed, tokens_lens, features_lens
|
||||||
|
)
|
||||||
|
return text_condition, padding_mask
|
||||||
|
|
||||||
|
def forward_text_inference_ratio_duration(
|
||||||
|
self,
|
||||||
|
tokens: List[List[int]],
|
||||||
|
prompt_tokens: List[List[int]],
|
||||||
|
prompt_features_lens: torch.Tensor,
|
||||||
|
speed: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Process text for inference, given text tokens and prompts,
|
||||||
|
feature lengths are predicted with the ratio of token numbers.
|
||||||
|
"""
|
||||||
|
device = (
|
||||||
|
self.device if isinstance(self, DDP) else next(self.parameters()).device
|
||||||
|
)
|
||||||
|
|
||||||
|
cat_tokens = [
|
||||||
|
prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
|
||||||
|
]
|
||||||
|
|
||||||
|
prompt_tokens_lens = torch.tensor(
|
||||||
|
[len(token) for token in prompt_tokens], dtype=torch.int64, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens)
|
||||||
|
|
||||||
|
features_lens = torch.ceil(
|
||||||
|
(prompt_features_lens / prompt_tokens_lens * cat_tokens_lens / speed)
|
||||||
|
).to(dtype=torch.int64)
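# The total output length is estimated from the prompt's frames-per-token
# rate: (prompt frames / prompt tokens) * (prompt + target tokens), then
# divided by `speed`, so that speed > 1.0 shortens the generated speech.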
|
||||||
|
|
||||||
|
text_condition, padding_mask = self.forward_text_condition(
|
||||||
|
cat_embed, cat_tokens_lens, features_lens
|
||||||
|
)
|
||||||
|
return text_condition, padding_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tokens: List[List[int]],
|
||||||
|
features: torch.Tensor,
|
||||||
|
features_lens: torch.Tensor,
|
||||||
|
noise: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
condition_drop_ratio: float = 0.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Forward pass of the model for training.
|
||||||
|
Args:
|
||||||
|
tokens: a list of list of token ids.
|
||||||
|
features: the acoustic features, with the shape (batch, seq_len, feat_dim).
|
||||||
|
features_lens: the length of each acoustic feature sequence, shape (batch,).
|
||||||
|
noise: the initial noise, with the shape (batch, seq_len, feat_dim).
|
||||||
|
t: the time step, with the shape (batch, 1, 1).
|
||||||
|
condition_drop_ratio: the ratio of dropped text condition.
|
||||||
|
Returns:
|
||||||
|
fm_loss: the flow-matching loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
(text_condition, padding_mask,) = self.forward_text_train(
|
||||||
|
tokens=tokens,
|
||||||
|
features_lens=features_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
speech_condition_mask = condition_time_mask(
|
||||||
|
features_lens=features_lens,
|
||||||
|
mask_percent=(0.7, 1.0),
|
||||||
|
max_len=features.size(1),
|
||||||
|
)
|
||||||
|
speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
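# speech_condition_mask is True on a randomly sized 70%-100% portion of each
# utterance's frames; those frames are zeroed in the speech condition, so the
# model learns to infill masked speech given the remaining unmasked frames.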
|
||||||
|
|
||||||
|
if condition_drop_ratio > 0.0:
|
||||||
|
drop_mask = (
|
||||||
|
torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
|
||||||
|
> condition_drop_ratio
|
||||||
|
)
|
||||||
|
text_condition = text_condition * drop_mask
|
||||||
|
|
||||||
|
xt = features * t + noise * (1 - t)
|
||||||
|
ut = features - noise # (B, T, F)
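# Conditional flow matching with a linear path: xt interpolates from pure
# noise at t=0 to clean features at t=1, and ut = d(xt)/dt is the constant
# target velocity that the decoder is trained to predict.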
|
||||||
|
|
||||||
|
vt = self.forward_fm_decoder(
|
||||||
|
t=t,
|
||||||
|
xt=xt,
|
||||||
|
text_condition=text_condition,
|
||||||
|
speech_condition=speech_condition,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss_mask = speech_condition_mask & (~padding_mask)
|
||||||
|
fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
|
||||||
|
|
||||||
|
return fm_loss
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
tokens: List[List[int]],
|
||||||
|
prompt_tokens: List[List[int]],
|
||||||
|
prompt_features: torch.Tensor,
|
||||||
|
prompt_features_lens: torch.Tensor,
|
||||||
|
features_lens: Optional[torch.Tensor] = None,
|
||||||
|
speed: float = 1.0,
|
||||||
|
t_shift: float = 1.0,
|
||||||
|
duration: str = "predict",
|
||||||
|
num_step: int = 5,
|
||||||
|
guidance_scale: float = 0.5,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Generate acoustic features, given text tokens, prompts feature
|
||||||
|
and prompt transcription's text tokens.
|
||||||
|
Args:
|
||||||
|
tokens: a list of list of text tokens.
|
||||||
|
prompt_tokens: a list of list of prompt tokens.
|
||||||
|
prompt_features: the prompt feature with the shape
|
||||||
|
(batch_size, seq_len, feat_dim).
|
||||||
|
prompt_features_lens: the length of each prompt feature,
|
||||||
|
with the shape (batch_size,).
|
||||||
|
features_lens: the length of the predicted feature, with the
|
||||||
|
shape (batch_size,). It is used only when duration is "real".
|
||||||
|
duration: "real" or "predict". If "real", the predicted
|
||||||
|
feature length is given by features_lens.
|
||||||
|
num_step: the number of steps to use in the ODE solver.
|
||||||
|
guidance_scale: the guidance scale for classifier-free guidance.
|
||||||
|
(Note: whether the distillation sampling path is used is determined
by self.distill, not by an argument.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert duration in ["real", "predict"]
|
||||||
|
|
||||||
|
if duration == "predict":
|
||||||
|
(
|
||||||
|
text_condition,
|
||||||
|
padding_mask,
|
||||||
|
) = self.forward_text_inference_ratio_duration(
|
||||||
|
tokens=tokens,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
prompt_features_lens=prompt_features_lens,
|
||||||
|
speed=speed,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert features_lens is not None
|
||||||
|
text_condition, padding_mask = self.forward_text_inference_gt_duration(
|
||||||
|
tokens=tokens,
|
||||||
|
features_lens=features_lens,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
prompt_features_lens=prompt_features_lens,
|
||||||
|
)
|
||||||
|
batch_size, num_frames, _ = text_condition.shape
|
||||||
|
|
||||||
|
speech_condition = torch.nn.functional.pad(
|
||||||
|
prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
|
||||||
|
) # (B, T, F)
|
||||||
|
|
||||||
|
# False means speech condition positions.
|
||||||
|
speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
|
||||||
|
speech_condition = torch.where(
|
||||||
|
speech_condition_mask.unsqueeze(-1),
|
||||||
|
torch.zeros_like(speech_condition),
|
||||||
|
speech_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
x0 = torch.randn(
|
||||||
|
batch_size, num_frames, self.feat_dim, device=text_condition.device
|
||||||
|
)
|
||||||
|
solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
|
||||||
|
|
||||||
|
x1 = solver.sample(
|
||||||
|
x=x0,
|
||||||
|
text_condition=text_condition,
|
||||||
|
speech_condition=speech_condition,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
num_step=num_step,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
t_shift=t_shift,
|
||||||
|
)
|
||||||
|
x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens
|
||||||
|
x1_prompt = torch.zeros(
|
||||||
|
x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device
|
||||||
|
)
|
||||||
|
x1_wo_prompt = torch.zeros(
|
||||||
|
x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device
|
||||||
|
)
|
||||||
|
for i in range(x1.size(0)):
|
||||||
|
x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[
|
||||||
|
i,
|
||||||
|
prompt_features_lens[i] : prompt_features_lens[i]
|
||||||
|
+ x1_wo_prompt_lens[i],
|
||||||
|
]
|
||||||
|
x1_prompt[i, : prompt_features_lens[i], :] = x1[
|
||||||
|
i, : prompt_features_lens[i]
|
||||||
|
]
|
||||||
|
|
||||||
|
return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens
|
||||||
|
|
||||||
|
def sample_intermediate(
|
||||||
|
self,
|
||||||
|
tokens: List[List[int]],
|
||||||
|
features: torch.Tensor,
|
||||||
|
features_lens: torch.Tensor,
|
||||||
|
noise: torch.Tensor,
|
||||||
|
speech_condition_mask: torch.Tensor,
|
||||||
|
t_start: torch.Tensor,
|
||||||
|
t_end: torch.Tensor,
|
||||||
|
num_step: int = 1,
|
||||||
|
guidance_scale: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Generate acoustic features in intermediate timesteps.
|
||||||
|
Args:
|
||||||
|
tokens: List of list of token ids.
|
||||||
|
features: The acoustic features, with the shape (batch, seq_len, feat_dim).
|
||||||
|
features_lens: The length of each acoustic feature sequence,
|
||||||
|
with the shape (batch,).
|
||||||
|
noise: The initial noise, with the shape (batch, seq_len, feat_dim).
|
||||||
|
speech_condition_mask: The mask for speech condition, True means
|
||||||
|
non-condition positions, with the shape (batch, seq_len).
|
||||||
|
t_start: The start timestep, with the shape (batch, 1, 1).
|
||||||
|
t_end: The end timestep, with the shape (batch, 1, 1).
|
||||||
|
num_step: The number of steps for sampling.
|
||||||
|
guidance_scale: The scale for classifier-free guidance inference,
|
||||||
|
with the shape (batch, 1, 1).
|
||||||
|
(Note: whether the distillation path is used is determined by
self.distill, not by an argument.)
|
||||||
|
"""
|
||||||
|
(text_condition, padding_mask,) = self.forward_text_train(
|
||||||
|
tokens=tokens,
|
||||||
|
features_lens=features_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
|
||||||
|
|
||||||
|
solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
|
||||||
|
|
||||||
|
x_t_end = solver.sample(
|
||||||
|
x=noise,
|
||||||
|
text_condition=text_condition,
|
||||||
|
speech_condition=speech_condition,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
num_step=num_step,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
t_start=t_start,
|
||||||
|
t_end=t_end,
|
||||||
|
)
|
||||||
|
x_t_end_lens = (~padding_mask).sum(-1)
|
||||||
|
return x_t_end, x_t_end_lens
|
||||||
|
|
||||||
|
|
||||||
|
class DistillTTSModelTrainWrapper(TtsModel):
|
||||||
|
"""Wrapper for training the distilled TTS model."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.distill = True
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tokens: List[List[int]],
|
||||||
|
features: torch.Tensor,
|
||||||
|
features_lens: torch.Tensor,
|
||||||
|
noise: torch.Tensor,
|
||||||
|
speech_condition_mask: torch.Tensor,
|
||||||
|
t_start: torch.Tensor,
|
||||||
|
t_end: torch.Tensor,
|
||||||
|
num_step: int = 1,
|
||||||
|
guidance_scale: torch.Tensor = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
return self.sample_intermediate(
|
||||||
|
tokens=tokens,
|
||||||
|
features=features,
|
||||||
|
features_lens=features_lens,
|
||||||
|
noise=noise,
|
||||||
|
speech_condition_mask=speech_condition_mask,
|
||||||
|
t_start=t_start,
|
||||||
|
t_end=t_end,
|
||||||
|
num_step=num_step,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
)
|
1256
egs/zipvoice/zipvoice/optim.py
Normal file
File diff suppressed because it is too large
1910
egs/zipvoice/zipvoice/scaling.py
Normal file
File diff suppressed because it is too large
277
egs/zipvoice/zipvoice/solver.py
Normal file
@ -0,0 +1,277 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionModel(torch.nn.Module):
|
||||||
|
"""A wrapper of diffusion models for inference.
|
||||||
|
Args:
|
||||||
|
model: The diffusion model.
|
||||||
|
distill: Whether it is a distillation model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
distill: bool = False,
|
||||||
|
func_name: str = "forward_fm_decoder",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.distill = distill
|
||||||
|
self.func_name = func_name
|
||||||
|
self.model_func = getattr(self.model, func_name)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
t: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
text_condition: torch.Tensor,
|
||||||
|
speech_condition: torch.Tensor,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
guidance_scale: Union[float, torch.Tensor] = 0.0,
|
||||||
|
**kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward function that handles classifier-free guidance.
|
||||||
|
Args:
|
||||||
|
t: The current timestep, a tensor of shape (batch, 1, 1) or a tensor of a single float.
|
||||||
|
x: The initial value, with the shape (batch, seq_len, emb_dim).
|
||||||
|
text_condition: The text condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
speech_condition: The speech condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
|
||||||
|
padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
|
||||||
|
guidance_scale: The scale of classifier-free guidance, a float or a tensor of shape (batch, 1, 1).
|
||||||
|
Returns:
|
||||||
|
The prediction with the shape (batch, seq_len, emb_dim).
|
||||||
|
"""
|
||||||
|
if not torch.is_tensor(guidance_scale):
|
||||||
|
guidance_scale = torch.tensor(
|
||||||
|
guidance_scale, dtype=t.dtype, device=t.device
|
||||||
|
)
|
||||||
|
if self.distill:
|
||||||
|
return self.model_func(
|
||||||
|
t=t,
|
||||||
|
xt=x,
|
||||||
|
text_condition=text_condition,
|
||||||
|
speech_condition=speech_condition,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if (guidance_scale == 0.0).all():
|
||||||
|
return self.model_func(
|
||||||
|
t=t,
|
||||||
|
xt=x,
|
||||||
|
text_condition=text_condition,
|
||||||
|
speech_condition=speech_condition,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if t.dim() != 0:
|
||||||
|
t = torch.cat([t] * 2, dim=0)
|
||||||
|
|
||||||
|
x = torch.cat([x] * 2, dim=0)
|
||||||
|
padding_mask = torch.cat([padding_mask] * 2, dim=0)
|
||||||
|
|
||||||
|
text_condition = torch.cat(
|
||||||
|
[torch.zeros_like(text_condition), text_condition], dim=0
|
||||||
|
)
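# The unconditional CFG branch always drops the text condition. For the
# speech condition the code below is time-dependent: at t > 0.5 it is also
# zeroed in the unconditional branch, while at t <= 0.5 it is kept in both
# branches and the guidance scale is doubled instead.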
|
||||||
|
|
||||||
|
if t.dim() == 0:
|
||||||
|
if t > 0.5:
|
||||||
|
speech_condition = torch.cat(
|
||||||
|
[torch.zeros_like(speech_condition), speech_condition], dim=0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
guidance_scale = guidance_scale * 2
|
||||||
|
speech_condition = torch.cat(
|
||||||
|
[speech_condition, speech_condition], dim=0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert t.dim() > 0, t
|
||||||
|
larger_t_index = (t > 0.5).squeeze(1).squeeze(1)
|
||||||
|
zero_speech_condition = torch.cat(
|
||||||
|
[torch.zeros_like(speech_condition), speech_condition], dim=0
|
||||||
|
)
|
||||||
|
speech_condition = torch.cat(
|
||||||
|
[speech_condition, speech_condition], dim=0
|
||||||
|
)
|
||||||
|
speech_condition[larger_t_index] = zero_speech_condition[larger_t_index]
|
||||||
|
guidance_scale[~larger_t_index[: larger_t_index.size(0) // 2]] *= 2
|
||||||
|
|
||||||
|
data_uncond, data_cond = self.model_func(
|
||||||
|
t=t,
|
||||||
|
xt=x,
|
||||||
|
text_condition=text_condition,
|
||||||
|
speech_condition=speech_condition,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
**kwargs
|
||||||
|
).chunk(2, dim=0)
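# Standard classifier-free guidance combination, equivalent to
# data_cond + guidance_scale * (data_cond - data_uncond): extrapolate from
# the unconditional prediction toward the conditional one.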
|
||||||
|
|
||||||
|
res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class EulerSolver:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
distill: bool = False,
|
||||||
|
func_name: str = "forward_fm_decoder",
|
||||||
|
):
|
||||||
|
"""Construct a Euler Solver
|
||||||
|
Args:
|
||||||
|
model: The diffusion model.
|
||||||
|
distill: Whether it is distillation model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.model = DiffusionModel(model, distill=distill, func_name=func_name)
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
text_condition: torch.Tensor,
|
||||||
|
speech_condition: torch.Tensor,
|
||||||
|
padding_mask: torch.Tensor,
|
||||||
|
num_step: int = 10,
|
||||||
|
guidance_scale: Union[float, torch.Tensor] = 0.0,
|
||||||
|
t_start: Union[float, torch.Tensor] = 0.0,
|
||||||
|
t_end: Union[float, torch.Tensor] = 1.0,
|
||||||
|
t_shift: float = 1.0,
|
||||||
|
**kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute the sample at time `t_end` by Euler Solver.
|
||||||
|
Args:
|
||||||
|
x: The initial value at time `t_start`, with the shape (batch, seq_len, emb_dim).
|
||||||
|
text_condition: The text condition of the diffision mode, with the shape (batch, seq_len, emb_dim).
|
||||||
|
speech_condition: The speech condition of the diffision model, with the shape (batch, seq_len, emb_dim).
|
||||||
|
padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
|
||||||
|
num_step: The number of ODE steps.
|
||||||
|
guidance_scale: The scale for classifier-free guidance, which is
|
||||||
|
a float or a tensor with the shape (batch, 1, 1).
|
||||||
|
t_start: the start timestep in the range of [0, 1],
|
||||||
|
which is a float or a tensor with the shape (batch, 1, 1).
|
||||||
|
t_end: the end timestep in the range of [0, 1],
|
||||||
|
which is a float or a tensor with the shape (batch, 1, 1).
|
||||||
|
t_shift: shift the t toward smaller numbers so that the sampling
|
||||||
|
will emphasize low SNR region. Should be in the range of (0, 1].
|
||||||
|
The shifting will be more significant when the number is smaller.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The approximated solution at time `t_end`.
|
||||||
|
"""
|
||||||
|
device = x.device
|
||||||
|
|
||||||
|
if torch.is_tensor(t_start) and t_start.dim() > 0:
|
||||||
|
timesteps = get_time_steps_batch(
|
||||||
|
t_start=t_start,
|
||||||
|
t_end=t_end,
|
||||||
|
num_step=num_step,
|
||||||
|
t_shift=t_shift,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
timesteps = get_time_steps(
|
||||||
|
t_start=t_start,
|
||||||
|
t_end=t_end,
|
||||||
|
num_step=num_step,
|
||||||
|
t_shift=t_shift,
|
||||||
|
device=device,
|
||||||
|
)
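# Integrate dx/dt = v(x, t) with explicit Euler steps over the (possibly
# shifted) timestep grid computed above.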
|
||||||
|
for step in range(num_step):
|
||||||
|
v = self.model(
|
||||||
|
t=timesteps[step],
|
||||||
|
x=x,
|
||||||
|
text_condition=text_condition,
|
||||||
|
speech_condition=speech_condition,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
x = x + v * (timesteps[step + 1] - timesteps[step])
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def get_time_steps(
|
||||||
|
t_start: float = 0.0,
|
||||||
|
t_end: float = 1.0,
|
||||||
|
num_step: int = 10,
|
||||||
|
t_shift: float = 1.0,
|
||||||
|
device: torch.device = torch.device("cpu"),
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute the intermediate time steps for sampling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t_start: The starting time of the sampling (default is 0).
|
||||||
|
t_end: The ending time of the sampling (default is 1).
|
||||||
|
num_step: The number of sampling steps.
|
||||||
|
t_shift: shift the t toward smaller numbers so that the sampling
|
||||||
|
will emphasize low SNR region. Should be in the range of (0, 1].
|
||||||
|
The shifting will be more significant when the number is smaller.
|
||||||
|
device: A torch device.
|
||||||
|
Returns:
|
||||||
|
The time step with the shape (num_step + 1,).
|
||||||
|
"""
|
||||||
|
|
||||||
|
timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
|
||||||
|
|
||||||
|
timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
|
||||||
|
|
||||||
|
return timesteps
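# A small worked example of the shift (values rounded): with t_start=0,
# t_end=1, num_step=2 and t_shift=0.5, the uniform grid [0.0, 0.5, 1.0]
# becomes [0.0, 0.333, 1.0], i.e. steps are packed toward small t (low SNR).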
|
||||||
|
|
||||||
|
|
||||||
|
def get_time_steps_batch(
|
||||||
|
t_start: torch.Tensor,
|
||||||
|
t_end: torch.Tensor,
|
||||||
|
num_step: int = 10,
|
||||||
|
t_shift: float = 1.0,
|
||||||
|
device: torch.device = torch.device("cpu"),
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute the intermediate time steps for sampling in the batch mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t_start: The starting time of the sampling (default is 0), with the shape (batch, 1, 1).
|
||||||
|
t_end: The ending time of the sampling (default is 1), with the shape (batch, 1, 1).
|
||||||
|
num_step: The number of sampling steps.
|
||||||
|
t_shift: shift the t toward smaller numbers so that the sampling
|
||||||
|
will emphasize low SNR region. Should be in the range of (0, 1].
|
||||||
|
The shifting will be more significant when the number is smaller.
|
||||||
|
device: A torch device.
|
||||||
|
Returns:
|
||||||
|
The time step with the shape (num_step + 1, N, 1, 1).
|
||||||
|
"""
|
||||||
|
while t_start.dim() > 1 and t_start.size(-1) == 1:
|
||||||
|
t_start = t_start.squeeze(-1)
|
||||||
|
while t_end.dim() > 1 and t_end.size(-1) == 1:
|
||||||
|
t_end = t_end.squeeze(-1)
|
||||||
|
assert t_start.dim() == t_end.dim() == 1
|
||||||
|
|
||||||
|
timesteps_shape = (num_step + 1, t_start.size(0))
|
||||||
|
timesteps = torch.zeros(timesteps_shape, device=device)
|
||||||
|
|
||||||
|
for i in range(t_start.size(0)):
|
||||||
|
timesteps[:, i] = torch.linspace(t_start[i], t_end[i], steps=num_step + 1)
|
||||||
|
|
||||||
|
timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
|
||||||
|
|
||||||
|
return timesteps.unsqueeze(-1).unsqueeze(-1)
|
570
egs/zipvoice/zipvoice/tokenizer.py
Normal file
@ -0,0 +1,570 @@
|
|||||||
|
# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao
|
||||||
|
# Han Zhu,
|
||||||
|
# Wei Kang)
|
||||||
|
#
|
||||||
|
# See ../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import unicodedata
|
||||||
|
from functools import reduce
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import cn2an
|
||||||
|
import inflect
|
||||||
|
import jieba
|
||||||
|
from pypinyin import Style, lazy_pinyin
|
||||||
|
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
|
||||||
|
|
||||||
|
try:
|
||||||
|
from piper_phonemize import phonemize_espeak
|
||||||
|
except Exception as ex:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{ex}\nPlease run\n"
|
||||||
|
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
|
||||||
|
)
|
||||||
|
|
||||||
|
_inflect = inflect.engine()
|
||||||
|
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
||||||
|
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
||||||
|
_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
|
||||||
|
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
||||||
|
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
||||||
|
_fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
|
||||||
|
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
||||||
|
_number_re = re.compile(r"[0-9]+")
|
||||||
|
|
||||||
|
# List of (regular expression, replacement) pairs for abbreviations:
|
||||||
|
_abbreviations = [
|
||||||
|
(re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
|
||||||
|
for x in [
|
||||||
|
("mrs", "misess"),
|
||||||
|
("mr", "mister"),
|
||||||
|
("dr", "doctor"),
|
||||||
|
("st", "saint"),
|
||||||
|
("co", "company"),
|
||||||
|
("jr", "junior"),
|
||||||
|
("maj", "major"),
|
||||||
|
("gen", "general"),
|
||||||
|
("drs", "doctors"),
|
||||||
|
("rev", "reverend"),
|
||||||
|
("lt", "lieutenant"),
|
||||||
|
("hon", "honorable"),
|
||||||
|
("sgt", "sergeant"),
|
||||||
|
("capt", "captain"),
|
||||||
|
("esq", "esquire"),
|
||||||
|
("ltd", "limited"),
|
||||||
|
("col", "colonel"),
|
||||||
|
("ft", "fort"),
|
||||||
|
("etc", "et cetera"),
|
||||||
|
("btw", "by the way"),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def intersperse(sequence, item=0):
|
||||||
|
result = [item] * (len(sequence) * 2 + 1)
|
||||||
|
result[1::2] = sequence
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def expand_abbreviations(text):
|
||||||
|
for regex, replacement in _abbreviations:
|
||||||
|
text = re.sub(regex, replacement, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_commas(m):
|
||||||
|
return m.group(1).replace(",", "")
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_decimal_point(m):
|
||||||
|
return m.group(1).replace(".", " point ")
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_percent(m):
|
||||||
|
return m.group(1).replace("%", " percent ")
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_dollars(m):
|
||||||
|
match = m.group(1)
|
||||||
|
parts = match.split(".")
|
||||||
|
if len(parts) > 2:
|
||||||
|
return " " + match + " dollars " # Unexpected format
|
||||||
|
dollars = int(parts[0]) if parts[0] else 0
|
||||||
|
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||||
|
if dollars and cents:
|
||||||
|
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||||
|
cent_unit = "cent" if cents == 1 else "cents"
|
||||||
|
return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
|
||||||
|
elif dollars:
|
||||||
|
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||||
|
return " %s %s " % (dollars, dollar_unit)
|
||||||
|
elif cents:
|
||||||
|
cent_unit = "cent" if cents == 1 else "cents"
|
||||||
|
return " %s %s " % (cents, cent_unit)
|
||||||
|
else:
|
||||||
|
return " zero dollars "
|
||||||
|
|
||||||
|
|
||||||
|
def fraction_to_words(numerator, denominator):
|
||||||
|
if numerator == 1 and denominator == 2:
|
||||||
|
return " one half "
|
||||||
|
if numerator == 1 and denominator == 4:
|
||||||
|
return " one quarter "
|
||||||
|
if denominator == 2:
|
||||||
|
return " " + _inflect.number_to_words(numerator) + " halves "
|
||||||
|
if denominator == 4:
|
||||||
|
return " " + _inflect.number_to_words(numerator) + " quarters "
|
||||||
|
return (
|
||||||
|
" "
|
||||||
|
+ _inflect.number_to_words(numerator)
|
||||||
|
+ " "
|
||||||
|
+ _inflect.ordinal(_inflect.number_to_words(denominator))
|
||||||
|
+ " "
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_fraction(m):
|
||||||
|
numerator = int(m.group(1))
|
||||||
|
denominator = int(m.group(2))
|
||||||
|
return fraction_to_words(numerator, denominator)
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_ordinal(m):
|
||||||
|
return " " + _inflect.number_to_words(m.group(0)) + " "
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_number(m):
|
||||||
|
num = int(m.group(0))
|
||||||
|
if num > 1000 and num < 3000:
|
||||||
|
if num == 2000:
|
||||||
|
return " two thousand "
|
||||||
|
elif num > 2000 and num < 2010:
|
||||||
|
return " two thousand " + _inflect.number_to_words(num % 100) + " "
|
||||||
|
elif num % 100 == 0:
|
||||||
|
return " " + _inflect.number_to_words(num // 100) + " hundred "
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
" "
|
||||||
|
+ _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
|
||||||
|
", ", " "
|
||||||
|
)
|
||||||
|
+ " "
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return " " + _inflect.number_to_words(num, andword="") + " "
|
||||||
|
|
||||||
|
|
||||||
|
# Normalize numbers pronunciation
|
||||||
|
def normalize_numbers(text):
|
||||||
|
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||||
|
text = re.sub(_pounds_re, r"\1 pounds", text)
|
||||||
|
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||||
|
text = re.sub(_fraction_re, _expand_fraction, text)
|
||||||
|
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||||
|
text = re.sub(_percent_number_re, _expand_percent, text)
|
||||||
|
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||||
|
text = re.sub(_number_re, _expand_number, text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# Convert numbers to Chinese pronunciation
|
||||||
|
def number_to_chinese(text):
|
||||||
|
text = cn2an.transform(text, "an2cn")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def map_punctuations(text):
|
||||||
|
text = text.replace(",", ",")
|
||||||
|
text = text.replace("。", ".")
|
||||||
|
text = text.replace("!", "!")
|
||||||
|
text = text.replace("?", "?")
|
||||||
|
text = text.replace(";", ";")
|
||||||
|
text = text.replace(":", ":")
|
||||||
|
text = text.replace("、", ",")
|
||||||
|
text = text.replace("‘", "'")
|
||||||
|
text = text.replace("“", '"')
|
||||||
|
text = text.replace("”", '"')
|
||||||
|
text = text.replace("’", "'")
|
||||||
|
text = text.replace("⋯", "…")
|
||||||
|
text = text.replace("···", "…")
|
||||||
|
text = text.replace("・・・", "…")
|
||||||
|
text = text.replace("...", "…")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def is_chinese(char):
|
||||||
|
if char >= "\u4e00" and char <= "\u9fa5":
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_alphabet(char):
|
||||||
|
if (char >= "\u0041" and char <= "\u005a") or (
|
||||||
|
char >= "\u0061" and char <= "\u007a"
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_hangul(char):
|
||||||
|
letters = unicodedata.normalize("NFD", char)
|
||||||
|
return all(
|
||||||
|
["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_japanese(char):
|
||||||
|
return any(
|
||||||
|
[
|
||||||
|
start <= char <= end
|
||||||
|
for start, end in [
|
||||||
|
("\u3041", "\u3096"),
|
||||||
|
("\u30a0", "\u30ff"),
|
||||||
|
("\uff5f", "\uff9f"),
|
||||||
|
("\u31f0", "\u31ff"),
|
||||||
|
("\u3220", "\u3243"),
|
||||||
|
("\u3280", "\u337f"),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_segment(text: str) -> List[str]:
|
||||||
|
# sentence --> [ch_part, en_part, ch_part, ...]
|
||||||
|
# example :
|
||||||
|
# input : 我们是小米人,是吗? Yes I think so!霍...啦啦啦
|
||||||
|
# output : [('我们是小米人,是吗? ', 'zh'), ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')]
|
||||||
|
segments = []
|
||||||
|
types = []
|
||||||
|
flag = 0
|
||||||
|
temp_seg = ""
|
||||||
|
temp_lang = ""
|
||||||
|
|
||||||
|
for i, ch in enumerate(text):
|
||||||
|
if is_chinese(ch):
|
||||||
|
types.append("zh")
|
||||||
|
elif is_alphabet(ch):
|
||||||
|
types.append("en")
|
||||||
|
else:
|
||||||
|
types.append("other")
|
||||||
|
|
||||||
|
assert len(types) == len(text)
|
||||||
|
|
||||||
|
for i in range(len(types)):
|
||||||
|
# find the first char of the seg
|
||||||
|
if flag == 0:
|
||||||
|
temp_seg += text[i]
|
||||||
|
temp_lang = types[i]
|
||||||
|
flag = 1
|
||||||
|
else:
|
||||||
|
if temp_lang == "other":
|
||||||
|
if types[i] == temp_lang:
|
||||||
|
temp_seg += text[i]
|
||||||
|
else:
|
||||||
|
temp_seg += text[i]
|
||||||
|
temp_lang = types[i]
|
||||||
|
else:
|
||||||
|
if types[i] == temp_lang:
|
||||||
|
temp_seg += text[i]
|
||||||
|
elif types[i] == "other":
|
||||||
|
temp_seg += text[i]
|
||||||
|
else:
|
||||||
|
segments.append((temp_seg, temp_lang))
|
||||||
|
temp_seg = text[i]
|
||||||
|
temp_lang = types[i]
|
||||||
|
flag = 1
|
||||||
|
|
||||||
|
segments.append((temp_seg, temp_lang))
|
||||||
|
return segments
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(text: str) -> str:
|
||||||
|
text = map_punctuations(text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_ZH(text: str) -> List[str]:
|
||||||
|
try:
|
||||||
|
text = number_to_chinese(text)
|
||||||
|
segs = list(jieba.cut(text))
|
||||||
|
full = lazy_pinyin(
|
||||||
|
segs, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True
|
||||||
|
)
|
||||||
|
phones = []
|
||||||
|
for x in full:
|
||||||
|
# valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
|
||||||
|
if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")):
|
||||||
|
phones.append(x)
|
||||||
|
continue
|
||||||
|
initial = to_initials(x, strict=False)
|
||||||
|
# don't want to share tokens with espeak tokens, so use tone3 style
|
||||||
|
final = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
|
||||||
|
if initial != "":
|
||||||
|
# don't want to share tokens with espeak tokens, so add a '0' after each initial
|
||||||
|
phones.append(initial + "0")
|
||||||
|
if final != "":
|
||||||
|
phones.append(final)
|
||||||
|
return phones
|
||||||
|
except Exception:
|
||||||
|
return []
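
# Illustrative sketch (not part of the original recipe): tokenize_ZH converts a Chinese
# string into pinyin initials (suffixed with "0") and tone3-style finals. The exact
# tokens depend on jieba segmentation and pypinyin tone sandhi; for example, "你好" is
# expected to yield something like ["n0", "i2", "h0", "ao3"].
def _demo_tokenize_zh():
    print(tokenize_ZH("你好"))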
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_EN(text: str) -> List[str]:
|
||||||
|
try:
|
||||||
|
text = expand_abbreviations(text)
|
||||||
|
text = normalize_numbers(text)
|
||||||
|
tokens = phonemize_espeak(text, "en-us")
|
||||||
|
tokens = reduce(lambda x, y: x + y, tokens)
|
||||||
|
return tokens
|
||||||
|
except Exception:
|
||||||
|
return []
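
# Illustrative sketch (not part of the original recipe): tokenize_EN expands
# abbreviations, normalizes numbers, and returns espeak (IPA-style) phoneme symbols for
# the whole sentence as a flat list; the exact symbols depend on the espeak backend.
def _demo_tokenize_en():
    print(tokenize_EN("Dr. Smith paid $3.50."))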
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerEmilia(object):
|
||||||
|
def __init__(self, token_file: Optional[str] = None, token_type="phone"):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tokens: the file that contains information that maps tokens to ids,
|
||||||
|
which is a text file with '{token} {token_id}' per line.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
token_type == "phone"
|
||||||
|
), f"Only support phone tokenizer for Emilia, but get {token_type}."
|
||||||
|
self.has_tokens = False
|
||||||
|
if token_file is None:
|
||||||
|
logging.debug(
|
||||||
|
"Initialize Tokenizer without tokens file, will fail when map to ids."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
self.token2id: Dict[str, int] = {}
|
||||||
|
with open(token_file, "r", encoding="utf-8") as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
info = line.rstrip().split("\t")
|
||||||
|
token, id = info[0], int(info[1])
|
||||||
|
assert token not in self.token2id, token
|
||||||
|
self.token2id[token] = id
|
||||||
|
self.pad_id = self.token2id["_"] # padding
|
||||||
|
|
||||||
|
self.vocab_size = len(self.token2id)
|
||||||
|
self.has_tokens = True
|
||||||
|
|
||||||
|
def texts_to_token_ids(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
) -> List[List[int]]:
|
||||||
|
return self.tokens_to_token_ids(self.texts_to_tokens(texts))
|
||||||
|
|
||||||
|
def texts_to_tokens(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
) -> List[List[str]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
texts:
|
||||||
|
A list of transcripts.
|
||||||
|
Returns:
|
||||||
|
Return a list of a list of tokens [utterance][token]
|
||||||
|
"""
|
||||||
|
for i in range(len(texts)):
|
||||||
|
# Text normalization
|
||||||
|
texts[i] = preprocess(texts[i])
|
||||||
|
|
||||||
|
phoneme_list = []
|
||||||
|
for text in texts:
|
||||||
|
# now only en and ch
|
||||||
|
segments = get_segment(text)
|
||||||
|
all_phoneme = []
|
||||||
|
for index in range(len(segments)):
|
||||||
|
seg = segments[index]
|
||||||
|
if seg[1] == "zh":
|
||||||
|
phoneme = tokenize_ZH(seg[0])
|
||||||
|
else:
|
||||||
|
if seg[1] != "en":
|
||||||
|
logging.warning(
|
||||||
|
f"The lang should be en, given {seg[1]}, skipping segment : {seg}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
phoneme = tokenize_EN(seg[0])
|
||||||
|
all_phoneme += phoneme
|
||||||
|
phoneme_list.append(all_phoneme)
|
||||||
|
return phoneme_list
|
||||||
|
|
||||||
|
def tokens_to_token_ids(
|
||||||
|
self,
|
||||||
|
tokens: List[List[str]],
|
||||||
|
intersperse_blank: bool = False,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tokens_list:
|
||||||
|
A list of token list, each corresponding to one utterance.
|
||||||
|
intersperse_blank:
|
||||||
|
Whether to intersperse blanks in the token sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a list of token id list [utterance][token_id]
|
||||||
|
"""
|
||||||
|
assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
|
||||||
|
token_ids = []
|
||||||
|
|
||||||
|
for tks in tokens:
|
||||||
|
ids = []
|
||||||
|
for t in tks:
|
||||||
|
if t not in self.token2id:
|
||||||
|
logging.warning(f"Skip OOV {t}")
|
||||||
|
continue
|
||||||
|
ids.append(self.token2id[t])
|
||||||
|
|
||||||
|
if intersperse_blank:
|
||||||
|
ids = intersperse(ids, self.pad_id)
|
||||||
|
|
||||||
|
token_ids.append(ids)
|
||||||
|
|
||||||
|
return token_ids
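
# Illustrative sketch (not part of the original recipe) of the typical call flow:
# build the tokenizer from the token file produced by data preparation, then map raw
# bilingual text to phone tokens and token ids.
def _demo_tokenizer_emilia():
    tokenizer = TokenizerEmilia(token_file="data/tokens_emilia.txt")
    texts = ["我们是小米人。 Yes I think so!"]
    tokens = tokenizer.texts_to_tokens(texts)          # [[phone, phone, ...]]
    token_ids = tokenizer.tokens_to_token_ids(tokens)  # [[int, int, ...]]
    print(tokens, token_ids)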
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerLibriTTS(object):
|
||||||
|
def __init__(self, token_file: str, token_type: str):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
token_type: the type of tokenizer, e.g., bpe, char, phone.
|
||||||
|
tokens: the file that contains information that maps tokens to ids,
|
||||||
|
which is a text file with '{token} {token_id}' per line if type is
|
||||||
|
char or phone, otherwise it is a bpe_model file.
|
||||||
|
"""
|
||||||
|
self.type = token_type
|
||||||
|
assert token_type in ["bpe", "char", "phone"]
|
||||||
|
# Parse token file
|
||||||
|
|
||||||
|
if token_type == "bpe":
|
||||||
|
import sentencepiece as spm
|
||||||
|
|
||||||
|
self.sp = spm.SentencePieceProcessor()
|
||||||
|
self.sp.load(token_file)
|
||||||
|
self.pad_id = self.sp.piece_to_id("<pad>")
|
||||||
|
self.vocab_size = self.sp.get_piece_size()
|
||||||
|
else:
|
||||||
|
self.token2id: Dict[str, int] = {}
|
||||||
|
with open(token_file, "r", encoding="utf-8") as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
info = line.rstrip().split("\t")
|
||||||
|
token, id = info[0], int(info[1])
|
||||||
|
assert token not in self.token2id, token
|
||||||
|
self.token2id[token] = id
|
||||||
|
self.pad_id = self.token2id["_"] # padding
|
||||||
|
self.vocab_size = len(self.token2id)
|
||||||
|
try:
|
||||||
|
from tacotron_cleaner.cleaners import custom_english_cleaners as cleaner
|
||||||
|
except Exception as ex:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{ex}\nPlease run\n"
|
||||||
|
"pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/"
|
||||||
|
)
|
||||||
|
self.cleaner = cleaner
|
||||||
|
|
||||||
|
def texts_to_token_ids(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
lang: str = "en-us",
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
texts:
|
||||||
|
A list of transcripts.
|
||||||
|
|
||||||
|
lang:
|
||||||
|
Language argument passed to phonemize_espeak().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a list of token id list [utterance][token_id]
|
||||||
|
"""
|
||||||
|
for i in range(len(texts)):
|
||||||
|
# Text normalization
|
||||||
|
texts[i] = self.cleaner(texts[i])
|
||||||
|
|
||||||
|
if self.type == "bpe":
|
||||||
|
token_ids_list = self.sp.encode(texts)
|
||||||
|
|
||||||
|
elif self.type == "phone":
|
||||||
|
token_ids_list = []
|
||||||
|
for text in texts:
|
||||||
|
tokens_list = phonemize_espeak(text.lower(), lang)
|
||||||
|
tokens = []
|
||||||
|
for t in tokens_list:
|
||||||
|
tokens.extend(t)
|
||||||
|
token_ids = []
|
||||||
|
for t in tokens:
|
||||||
|
if t not in self.token2id:
|
||||||
|
logging.warning(f"Skip OOV {t}")
|
||||||
|
continue
|
||||||
|
token_ids.append(self.token2id[t])
|
||||||
|
|
||||||
|
token_ids_list.append(token_ids)
|
||||||
|
else:
|
||||||
|
token_ids_list = []
|
||||||
|
for text in texts:
|
||||||
|
token_ids = []
|
||||||
|
for t in text:
|
||||||
|
if t not in self.token2id:
|
||||||
|
logging.warning(f"Skip OOV {t}")
|
||||||
|
continue
|
||||||
|
token_ids.append(self.token2id[t])
|
||||||
|
|
||||||
|
token_ids_list.append(token_ids)
|
||||||
|
|
||||||
|
return token_ids_list
|
||||||
|
|
||||||
|
def tokens_to_token_ids(
|
||||||
|
self,
|
||||||
|
tokens_list: List[str],
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
tokens_list:
|
||||||
|
A list of token list, each corresponding to one utterance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a list of token id list [utterance][token_id]
|
||||||
|
"""
|
||||||
|
token_ids_list = []
|
||||||
|
|
||||||
|
for tokens in tokens_list:
|
||||||
|
token_ids = []
|
||||||
|
for t in tokens:
|
||||||
|
if t not in self.token2id:
|
||||||
|
logging.warning(f"Skip OOV {t}")
|
||||||
|
continue
|
||||||
|
token_ids.append(self.token2id[t])
|
||||||
|
|
||||||
|
token_ids_list.append(token_ids)
|
||||||
|
|
||||||
|
return token_ids_list
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
text = "我们是5年小米人,是吗? Yes I think so! mr king, 5 years, from 2019 to 2024. 霍...啦啦啦超过90%的人咯...?!9204"
|
||||||
|
tokenizer = TokenizerEmilia()
|
||||||
|
tokens = tokenizer.texts_to_tokens([text])
|
||||||
|
print(f"tokens : {tokens}")
|
||||||
|
tokens2 = "|".join(tokens[0])
|
||||||
|
print(f"tokens2 : {tokens2}")
|
||||||
|
tokens2 = tokens2.split("|")
|
||||||
|
assert tokens[0] == tokens2
|
egs/zipvoice/zipvoice/train_distill.py (new file, 1043 lines; diff not shown: file too large)
egs/zipvoice/zipvoice/train_flow.py (new file, 1108 lines; diff not shown: file too large)

egs/zipvoice/zipvoice/tts_datamodule.py (new file, 456 lines)
@@ -0,0 +1,456 @@
# Copyright 2021 Piotr Żelasko
|
||||||
|
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||||
|
# Zengwei Yao,
|
||||||
|
# Zengrui Jin,
|
||||||
|
# Han Zhu,
|
||||||
|
# Wei Kang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from feature import TorchAudioFbank, TorchAudioFbankConfig
|
||||||
|
from lhotse import CutSet, load_manifest_lazy, validate
|
||||||
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||||
|
DynamicBucketingSampler,
|
||||||
|
PrecomputedFeatures,
|
||||||
|
SimpleCutSampler,
|
||||||
|
)
|
||||||
|
from lhotse.dataset.collation import collate_audio
|
||||||
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
|
BatchIO,
|
||||||
|
OnTheFlyFeatures,
|
||||||
|
)
|
||||||
|
from lhotse.utils import fix_random_seed, ifnone
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
class _SeedWorkers:
|
||||||
|
def __init__(self, seed: int):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def __call__(self, worker_id: int):
|
||||||
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLING_RATE = 24000
|
||||||
|
|
||||||
|
|
||||||
|
class TtsDataModule:
|
||||||
|
"""
|
||||||
|
DataModule for tts experiments.
|
||||||
|
It assumes there is always one train and valid dataloader,
|
||||||
|
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||||
|
and test-other).
|
||||||
|
|
||||||
|
It contains all the common data pipeline modules used in ASR
|
||||||
|
experiments, e.g.:
|
||||||
|
- dynamic batch size,
|
||||||
|
- bucketing samplers,
|
||||||
|
- cut concatenation,
|
||||||
|
- on-the-fly feature extraction
|
||||||
|
|
||||||
|
This class should be derived for specific corpora used in ASR tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args: argparse.Namespace):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group(
|
||||||
|
title="TTS data related options",
|
||||||
|
description="These options are used for the preparation of "
|
||||||
|
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||||
|
"effective batch sizes, sampling strategies, applied data "
|
||||||
|
"augmentations, etc.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--manifest-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("data/fbank_emilia"),
|
||||||
|
help="Path to directory with train/valid/test cuts.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--max-duration",
|
||||||
|
type=int,
|
||||||
|
default=200.0,
|
||||||
|
help="Maximum pooled recordings duration (seconds) in a "
|
||||||
|
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--bucketing-sampler",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, the batches will come from buckets of "
|
||||||
|
"similar duration (saves padding frames).",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-buckets",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="The number of buckets for the DynamicBucketingSampler"
|
||||||
|
"(you might want to increase it for larger datasets).",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--on-the-fly-feats",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, use on-the-fly cut mixing and feature "
|
||||||
|
"extraction. Will drop existing precomputed feature manifests "
|
||||||
|
"if available.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--shuffle",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled (=default), the examples will be "
|
||||||
|
"shuffled for each epoch.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--drop-last",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to drop last batch. Used by sampler.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--return-cuts",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, each batch will have the "
|
||||||
|
"field: batch['cut'] with the cuts that "
|
||||||
|
"were used to construct it.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="The number of training dataloader workers that "
|
||||||
|
"collect the batches.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--input-strategy",
|
||||||
|
type=str,
|
||||||
|
default="PrecomputedFeatures",
|
||||||
|
help="AudioSamples or PrecomputedFeatures",
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_dataloaders(
|
||||||
|
self,
|
||||||
|
cuts_train: CutSet,
|
||||||
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> DataLoader:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
cuts_train:
|
||||||
|
CutSet for training.
|
||||||
|
sampler_state_dict:
|
||||||
|
The state dict for the training sampler.
|
||||||
|
"""
|
||||||
|
logging.info("About to create train dataset")
|
||||||
|
train = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=True,
|
||||||
|
return_spk_ids=True,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
sampling_rate = SAMPLING_RATE
|
||||||
|
config = TorchAudioFbankConfig(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
n_mels=100,
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
)
|
||||||
|
train = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=True,
|
||||||
|
return_spk_ids=True,
|
||||||
|
feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.bucketing_sampler:
|
||||||
|
logging.info("Using DynamicBucketingSampler.")
|
||||||
|
train_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
buffer_size=self.args.num_buckets * 2000,
|
||||||
|
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||||
|
drop_last=self.args.drop_last,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Using SimpleCutSampler.")
|
||||||
|
train_sampler = SimpleCutSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
)
|
||||||
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
|
if sampler_state_dict is not None:
|
||||||
|
logging.info("Loading sampler state dict")
|
||||||
|
train_sampler.load_state_dict(sampler_state_dict)
|
||||||
|
|
||||||
|
# 'seed' is derived from the current random state, which will have
|
||||||
|
# previously been set in the main process.
|
||||||
|
seed = torch.randint(0, 100000, ()).item()
|
||||||
|
worker_init_fn = _SeedWorkers(seed)
|
||||||
|
|
||||||
|
train_dl = DataLoader(
|
||||||
|
train,
|
||||||
|
sampler=train_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
persistent_workers=False,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_dl
|
||||||
|
|
||||||
|
def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||||
|
logging.info("About to create dev dataset")
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
sampling_rate = SAMPLING_RATE
|
||||||
|
config = TorchAudioFbankConfig(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
n_mels=100,
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
)
|
||||||
|
validate = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=True,
|
||||||
|
return_spk_ids=True,
|
||||||
|
feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
validate = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=True,
|
||||||
|
return_spk_ids=True,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
dev_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_valid,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
logging.info("About to create valid dataloader")
|
||||||
|
dev_dl = DataLoader(
|
||||||
|
validate,
|
||||||
|
sampler=dev_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=2,
|
||||||
|
persistent_workers=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return dev_dl
|
||||||
|
|
||||||
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
|
logging.info("About to create test dataset")
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
sampling_rate = SAMPLING_RATE
|
||||||
|
config = TorchAudioFbankConfig(
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
n_mels=100,
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
)
|
||||||
|
test = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=True,
|
||||||
|
return_spk_ids=True,
|
||||||
|
feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
return_audio=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
test = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=True,
|
||||||
|
return_spk_ids=True,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
return_audio=True,
|
||||||
|
)
|
||||||
|
test_sampler = DynamicBucketingSampler(
|
||||||
|
cuts,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
logging.info("About to create test dataloader")
|
||||||
|
test_dl = DataLoader(
|
||||||
|
test,
|
||||||
|
batch_size=None,
|
||||||
|
sampler=test_sampler,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
)
|
||||||
|
return test_dl
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_emilia_EN_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get train the EN subset")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_emilia_ZH_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get train the ZH subset")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def dev_emilia_EN_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get dev the EN subset")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def dev_emilia_ZH_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get dev the ZH subset")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_libritts_cuts(self) -> CutSet:
|
||||||
|
logging.info(
|
||||||
|
"About to get the shuffled train-clean-100, \
|
||||||
|
train-clean-360 and train-other-500 cuts"
|
||||||
|
)
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def dev_libritts_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get dev-clean cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz"
|
||||||
|
)
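
# Illustrative sketch (not part of the original recipe): wiring TtsDataModule into a
# training script. The flags mirror add_arguments() above; the cut manifests are the
# ones produced by the data preparation stage.
def _demo_datamodule():
    parser = argparse.ArgumentParser()
    TtsDataModule.add_arguments(parser)
    args = parser.parse_args(["--manifest-dir", "data/fbank_emilia", "--max-duration", "200"])
    datamodule = TtsDataModule(args)
    train_cuts = datamodule.train_emilia_EN_cuts()
    train_dl = datamodule.train_dataloaders(train_cuts)
    for batch in train_dl:
        # features: (B, T, F); features_lens: (B,)
        print(batch["features"].shape, batch["features_lens"])
        break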
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechSynthesisDataset(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
The PyTorch Dataset for the speech synthesis task.
|
||||||
|
Each item in this dataset is a dict of:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
{
|
||||||
|
'audio': (B x NumSamples) float tensor
|
||||||
|
'features': (B x NumFrames x NumFeatures) float tensor
|
||||||
|
'audio_lens': (B, ) int tensor
|
||||||
|
'features_lens': (B, ) int tensor
|
||||||
|
'text': List[str] of len B # when return_text=True
|
||||||
|
'tokens': List[List[str]] # when return_tokens=True
|
||||||
|
'speakers': List[str] of len B # when return_spk_ids=True
|
||||||
|
'cut': List of Cuts # when return_cuts=True
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
||||||
|
feature_input_strategy: BatchIO = PrecomputedFeatures(),
|
||||||
|
feature_transforms: Union[Sequence[Callable], Callable] = None,
|
||||||
|
return_text: bool = True,
|
||||||
|
return_tokens: bool = False,
|
||||||
|
return_spk_ids: bool = False,
|
||||||
|
return_cuts: bool = False,
|
||||||
|
return_audio: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cut_transforms = ifnone(cut_transforms, [])
|
||||||
|
self.feature_input_strategy = feature_input_strategy
|
||||||
|
|
||||||
|
self.return_text = return_text
|
||||||
|
self.return_tokens = return_tokens
|
||||||
|
self.return_spk_ids = return_spk_ids
|
||||||
|
self.return_cuts = return_cuts
|
||||||
|
self.return_audio = return_audio
|
||||||
|
|
||||||
|
if feature_transforms is None:
|
||||||
|
feature_transforms = []
|
||||||
|
elif not isinstance(feature_transforms, Sequence):
|
||||||
|
feature_transforms = [feature_transforms]
|
||||||
|
|
||||||
|
assert all(
|
||||||
|
isinstance(transform, Callable) for transform in feature_transforms
|
||||||
|
), "Feature transforms must be Callable"
|
||||||
|
self.feature_transforms = feature_transforms
|
||||||
|
|
||||||
|
def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
|
||||||
|
validate_for_tts(cuts)
|
||||||
|
|
||||||
|
for transform in self.cut_transforms:
|
||||||
|
cuts = transform(cuts)
|
||||||
|
|
||||||
|
features, features_lens = self.feature_input_strategy(cuts)
|
||||||
|
|
||||||
|
for transform in self.feature_transforms:
|
||||||
|
features = transform(features)
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"features": features,
|
||||||
|
"features_lens": features_lens,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.return_audio:
|
||||||
|
audio, audio_lens = collate_audio(cuts)
|
||||||
|
batch["audio"] = audio
|
||||||
|
batch["audio_lens"] = audio_lens
|
||||||
|
|
||||||
|
if self.return_text:
|
||||||
|
# use normalized text
|
||||||
|
text = [cut.supervisions[0].normalized_text for cut in cuts]
|
||||||
|
batch["text"] = text
|
||||||
|
|
||||||
|
if self.return_tokens:
|
||||||
|
tokens = [cut.tokens for cut in cuts]
|
||||||
|
batch["tokens"] = tokens
|
||||||
|
|
||||||
|
if self.return_spk_ids:
|
||||||
|
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
|
||||||
|
|
||||||
|
if self.return_cuts:
|
||||||
|
batch["cut"] = [cut for cut in cuts]
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def validate_for_tts(cuts: CutSet) -> None:
|
||||||
|
validate(cuts)
|
||||||
|
for cut in cuts:
|
||||||
|
assert (
|
||||||
|
len(cut.supervisions) == 1
|
||||||
|
), "Only the Cuts with single supervision are supported."
egs/zipvoice/zipvoice/utils.py (new file, 219 lines)
@@ -0,0 +1,219 @@
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
|
||||||
|
class AttributeDict(dict):
|
||||||
|
def __getattr__(self, key):
|
||||||
|
if key in self:
|
||||||
|
return self[key]
|
||||||
|
raise AttributeError(f"No such attribute '{key}'")
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
self[key] = value
|
||||||
|
|
||||||
|
def __delattr__(self, key):
|
||||||
|
if key in self:
|
||||||
|
del self[key]
|
||||||
|
return
|
||||||
|
raise AttributeError(f"No such attribute '{key}'")
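
# Illustrative sketch (not part of the original recipe): AttributeDict is a plain dict
# whose keys can also be read and written as attributes.
def _demo_attribute_dict():
    params = AttributeDict({"sampling_rate": 24000})
    params.feat_dim = 100  # attribute-style write
    assert params["feat_dim"] == 100 and params.sampling_rate == 24000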
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_input(
|
||||||
|
params: AttributeDict,
|
||||||
|
batch: dict,
|
||||||
|
device: torch.device,
|
||||||
|
tokenizer: Optional[Any] = None,
|
||||||
|
return_tokens: bool = False,
|
||||||
|
return_feature: bool = False,
|
||||||
|
return_audio: bool = False,
|
||||||
|
return_prompt: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Parse the features and targets of the current batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
tokenizer:
Used to convert text or tokens to token ids.
|
||||||
|
device:
|
||||||
|
The device of Tensor.
|
||||||
|
"""
|
||||||
|
return_list = []
|
||||||
|
|
||||||
|
if return_tokens:
|
||||||
|
assert tokenizer is not None
|
||||||
|
|
||||||
|
if params.token_type == "phone":
|
||||||
|
tokens = tokenizer.tokens_to_token_ids(batch["tokens"])
|
||||||
|
else:
|
||||||
|
tokens = tokenizer.texts_to_token_ids(batch["text"])
|
||||||
|
return_list += [tokens]
|
||||||
|
|
||||||
|
if return_feature:
|
||||||
|
features = batch["features"].to(device)
|
||||||
|
features_lens = batch["features_lens"].to(device)
|
||||||
|
return_list += [features * params.feat_scale, features_lens]
|
||||||
|
|
||||||
|
if return_audio:
|
||||||
|
return_list += [batch["audio"], batch["audio_lens"]]
|
||||||
|
|
||||||
|
if return_prompt:
|
||||||
|
if return_tokens:
|
||||||
|
if params.token_type == "phone":
|
||||||
|
prompt_tokens = tokenizer.tokens_to_token_ids(batch["prompt"]["tokens"])
|
||||||
|
else:
|
||||||
|
prompt_tokens = tokenizer.texts_to_token_ids(batch["prompt"]["text"])
|
||||||
|
return_list += [prompt_tokens]
|
||||||
|
if return_feature:
|
||||||
|
prompt_features = batch["prompt"]["features"].to(device)
|
||||||
|
prompt_features_lens = batch["prompt"]["features_lens"].to(device)
|
||||||
|
return_list += [prompt_features * params.feat_scale, prompt_features_lens]
|
||||||
|
if return_audio:
|
||||||
|
return_list += [batch["prompt"]["audio"], batch["prompt"]["audio_lens"]]
|
||||||
|
|
||||||
|
return return_list
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_avg_tokens_durations(features_lens, tokens_lens):
|
||||||
|
tokens_durations = []
|
||||||
|
for i in range(len(features_lens)):
|
||||||
|
utt_duration = features_lens[i]
|
||||||
|
avg_token_duration = utt_duration // tokens_lens[i]
|
||||||
|
tokens_durations.append([avg_token_duration] * tokens_lens[i])
|
||||||
|
return tokens_durations
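
# Illustrative sketch (not part of the original recipe): each utterance gets a flat
# duration of features_len // tokens_len frames per token, e.g.
# features_lens=[10], tokens_lens=[3] -> [[3, 3, 3]].
def _demo_prepare_avg_tokens_durations():
    print(prepare_avg_tokens_durations(features_lens=[10], tokens_lens=[3]))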
|
||||||
|
|
||||||
|
|
||||||
|
def pad_labels(y: List[List[int]], pad_id: int, device: torch.device):
|
||||||
|
"""
|
||||||
|
Pad the transcripts to the same length with zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y: the transcripts, which is a list of a list
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a Tensor of padded transcripts.
|
||||||
|
"""
|
||||||
|
y = [l + [pad_id] for l in y]
|
||||||
|
length = max([len(l) for l in y])
|
||||||
|
y = [l + [pad_id] * (length - len(l)) for l in y]
|
||||||
|
return torch.tensor(y, dtype=torch.int64, device=device)
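
# Illustrative sketch (not part of the original recipe): every transcript gets one
# trailing pad_id, then all are right-padded to the longest length, e.g.
# y=[[1, 2], [3]] with pad_id=0 -> tensor([[1, 2, 0], [3, 0, 0]]).
def _demo_pad_labels():
    print(pad_labels([[1, 2], [3]], pad_id=0, device=torch.device("cpu")))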
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Gets position in the transcript for each frame, i.e. the position
|
||||||
|
in the symbol-sequence to look up.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
durations:
|
||||||
|
Duration of each token in transcripts.
|
||||||
|
num_frames:
|
||||||
|
The maximum frame length of the current batch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a Tensor of shape (batch_size, num_frames)
|
||||||
|
"""
|
||||||
|
durations = [x + [num_frames - sum(x)] for x in durations]
|
||||||
|
batch_size = len(durations)
|
||||||
|
ans = torch.zeros(batch_size, num_frames, dtype=torch.int64)
|
||||||
|
for b in range(batch_size):
|
||||||
|
this_dur = durations[b]
|
||||||
|
cur_frame = 0
|
||||||
|
for i, d in enumerate(this_dur):
|
||||||
|
ans[b, cur_frame : cur_frame + d] = i
|
||||||
|
cur_frame += d
|
||||||
|
assert cur_frame == num_frames, (cur_frame, num_frames)
|
||||||
|
return ans
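
# Illustrative sketch (not part of the original recipe): a final "catch-all" entry
# absorbs the remaining frames, so durations [[2, 3]] with num_frames=6 map the frames
# to token indices [0, 0, 1, 1, 1, 2].
def _demo_get_tokens_index():
    print(get_tokens_index([[2, 3]], num_frames=6))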
|
||||||
|
|
||||||
|
|
||||||
|
def to_int_tuple(s: str):
|
||||||
|
return tuple(map(int, s.split(",")))
|
||||||
|
|
||||||
|
|
||||||
|
def get_adjusted_batch_count(params: AttributeDict) -> float:
|
||||||
|
# returns the number of batches we would have used so far if we had used the reference
|
||||||
|
# duration. This is for purposes of set_batch_count().
|
||||||
|
return (
|
||||||
|
params.batch_idx_train
|
||||||
|
* (params.max_duration * params.world_size)
|
||||||
|
/ params.ref_duration
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
|
||||||
|
if isinstance(model, DDP):
|
||||||
|
# get underlying nn.Module
|
||||||
|
model = model.module
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if hasattr(module, "batch_count"):
|
||||||
|
module.batch_count = batch_count
|
||||||
|
if hasattr(module, "name"):
|
||||||
|
module.name = name
|
||||||
|
|
||||||
|
|
||||||
|
def condition_time_mask(
|
||||||
|
features_lens: torch.Tensor, mask_percent: Tuple[float, float], max_len: int = 0
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply Time masking.
|
||||||
|
Args:
|
||||||
|
features_lens:
|
||||||
|
input tensor of shape ``(B)``
|
||||||
|
mask_percent:
the (min, max) range of the fraction of each utterance to mask.
|
||||||
|
max_len:
|
||||||
|
the maximum length of the mask.
|
||||||
|
Returns:
|
||||||
|
Return a 2-D bool tensor (B, T), where masked positions
|
||||||
|
are filled with `True` and non-masked positions are
|
||||||
|
filled with `False`.
|
||||||
|
"""
|
||||||
|
mask_size = (
|
||||||
|
torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
|
||||||
|
* features_lens
|
||||||
|
).to(torch.int64)
|
||||||
|
mask_starts = (
|
||||||
|
torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
|
||||||
|
).to(torch.int64)
|
||||||
|
mask_ends = mask_starts + mask_size
|
||||||
|
max_len = max(max_len, features_lens.max())
|
||||||
|
seq_range = torch.arange(0, max_len, device=features_lens.device)
|
||||||
|
mask = (seq_range[None, :] >= mask_starts[:, None]) & (
|
||||||
|
seq_range[None, :] < mask_ends[:, None]
|
||||||
|
)
|
||||||
|
return mask
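
# Illustrative sketch (not part of the original recipe): mask a random contiguous span
# covering between 50% and 100% of each utterance; masked frames are True. The
# (0.5, 1.0) range here is only an example value.
def _demo_condition_time_mask():
    lens = torch.tensor([8, 6])
    mask = condition_time_mask(lens, mask_percent=(0.5, 1.0))
    print(mask.shape, mask.sum(dim=1))  # torch.Size([2, 8]) and per-utterance masked counts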
|
||||||
|
|
||||||
|
|
||||||
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
lengths:
|
||||||
|
A 1-D tensor containing sentence lengths.
|
||||||
|
max_len:
|
||||||
|
The length of masks.
|
||||||
|
Returns:
|
||||||
|
Return a 2-D bool tensor, where masked positions
|
||||||
|
are filled with `True` and non-masked positions are
|
||||||
|
filled with `False`.
|
||||||
|
|
||||||
|
>>> lengths = torch.tensor([1, 3, 2, 5])
|
||||||
|
>>> make_pad_mask(lengths)
|
||||||
|
tensor([[False, True, True, True, True],
|
||||||
|
[False, False, False, True, True],
|
||||||
|
[False, False, True, True, True],
|
||||||
|
[False, False, False, False, False]])
|
||||||
|
"""
|
||||||
|
assert lengths.ndim == 1, lengths.ndim
|
||||||
|
max_len = max(max_len, lengths.max())
|
||||||
|
n = lengths.size(0)
|
||||||
|
seq_range = torch.arange(0, max_len, device=lengths.device)
|
||||||
|
expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)

return expanded_lengths >= lengths.unsqueeze(-1)
|
egs/zipvoice/zipvoice/zipformer.py (new file, 1648 lines; diff not shown: file too large)

egs/zipvoice/zipvoice/zipvoice_infer.py (new file, 642 lines)
@@ -0,0 +1,642 @@
#!/usr/bin/env python3
|
||||||
|
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script generates speech with our pre-trained ZipVoice or
|
||||||
|
ZipVoice-Distill models. Required models will be automatically
|
||||||
|
downloaded from HuggingFace.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
Note: If you are having trouble connecting to HuggingFace,
you can try switching the endpoint to a mirror site:
|
||||||
|
|
||||||
|
export HF_ENDPOINT=https://hf-mirror.com
|
||||||
|
|
||||||
|
(1) Inference of a single sentence:
|
||||||
|
|
||||||
|
python3 zipvoice/zipvoice_infer.py \
|
||||||
|
--model-name "zipvoice_distill" \
|
||||||
|
--prompt-wav prompt.wav \
|
||||||
|
--prompt-text "I am a prompt." \
|
||||||
|
--text "I am a sentence." \
|
||||||
|
--res-wav-path result.wav
|
||||||
|
|
||||||
|
(2) Inference of a list of sentences:
|
||||||
|
python3 zipvoice/zipvoice_infer.py \
|
||||||
|
--model-name "zipvoice-distill" \
|
||||||
|
--test-list test.tsv \
|
||||||
|
--res-dir results
|
||||||
|
|
||||||
|
`--model-name` can be `zipvoice` or `zipvoice_distill`,
|
||||||
|
which are the models before and after distillation, respectively.
|
||||||
|
|
||||||
|
Each line of `test.tsv` is in the format of
|
||||||
|
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime as dt
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import safetensors.torch
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchaudio
|
||||||
|
from feature import TorchAudioFbank, TorchAudioFbankConfig
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from model import get_distill_model, get_model
|
||||||
|
from tokenizer import TokenizerEmilia
|
||||||
|
from utils import AttributeDict
|
||||||
|
from vocos import Vocos
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
type=str,
|
||||||
|
default="zipvoice_distill",
|
||||||
|
choices=["zipvoice", "zipvoice_distill"],
|
||||||
|
help="The model used for inference",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-list",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The list of prompt speech, prompt_transcription, "
|
||||||
|
"and text to synthesizein the format of "
|
||||||
|
"'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-wav",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The prompt wav to mimic",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-text",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The transcription of the prompt wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--text",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The text to synthesize",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--res-dir",
|
||||||
|
type=str,
|
||||||
|
default="results",
|
||||||
|
help="Path name of the generated wavs dir, "
|
||||||
|
"used when decdode-list is not None",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--res-wav-path",
|
||||||
|
type=str,
|
||||||
|
default="result.wav",
|
||||||
|
help="Path name of the generated wav path, " "used when decdode-list is None",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--guidance-scale",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="The scale of classifier-free guidance during inference.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-step",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="The number of sampling steps.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--feat-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="The scale factor of fbank feature",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--speed",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--t-shift",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="Shift t to smaller ones if t_shift < 1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--target-rms",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="Target speech normalization rms value",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=666,
|
||||||
|
help="Random seed",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument(
|
||||||
|
"--fm-decoder-downsampling-factor",
|
||||||
|
type=str,
|
||||||
|
default="1,2,4,2,1",
|
||||||
|
help="Downsampling factor for each stack of encoder layers.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--fm-decoder-num-layers",
|
||||||
|
type=str,
|
||||||
|
default="2,2,4,4,4",
|
||||||
|
help="Number of zipformer encoder layers per stack, comma separated.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--fm-decoder-cnn-module-kernel",
|
||||||
|
type=str,
|
||||||
|
default="31,15,7,15,31",
|
||||||
|
help="Sizes of convolutional kernels in convolution modules "
|
||||||
|
"in each encoder stack: a single int or comma-separated list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--fm-decoder-feedforward-dim",
|
||||||
|
type=int,
|
||||||
|
default=1536,
|
||||||
|
help="Feedforward dimension of the zipformer encoder layers, "
|
||||||
|
"per stack, comma separated.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--fm-decoder-num-heads",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of attention heads in the zipformer encoder layers: "
|
||||||
|
"a single int or comma-separated list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--fm-decoder-dim",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Embedding dimension in encoder stacks: a single int "
|
||||||
|
"or comma-separated list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-encoder-downsampling-factor",
|
||||||
|
type=str,
|
||||||
|
default="1",
|
||||||
|
help="Downsampling factor for each stack of encoder layers.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-encoder-num-layers",
|
||||||
|
type=str,
|
||||||
|
default="4",
|
||||||
|
help="Number of zipformer encoder layers per stack, comma separated.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-encoder-feedforward-dim",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Feedforward dimension of the zipformer encoder layers, "
|
||||||
|
"per stack, comma separated.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-encoder-cnn-module-kernel",
|
||||||
|
type=str,
|
||||||
|
default="9",
|
||||||
|
help="Sizes of convolutional kernels in convolution modules in "
|
||||||
|
"each encoder stack: a single int or comma-separated list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-encoder-num-heads",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of attention heads in the zipformer encoder layers: "
|
||||||
|
"a single int or comma-separated list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-encoder-dim",
|
||||||
|
type=int,
|
||||||
|
default=192,
|
||||||
|
help="Embedding dimension in encoder stacks: "
|
||||||
|
"a single int or comma-separated list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--query-head-dim",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Query/key dimension per head in encoder stacks: "
|
||||||
|
"a single int or comma-separated list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--value-head-dim",
|
||||||
|
type=int,
|
||||||
|
default=12,
|
||||||
|
help="Value dimension per head in encoder stacks: "
|
||||||
|
"a single int or comma-separated list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--pos-head-dim",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Positional-encoding dimension per head in encoder stacks: "
|
||||||
|
"a single int or comma-separated list.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--pos-dim",
|
||||||
|
type=int,
|
||||||
|
default=48,
|
||||||
|
help="Positional-encoding embedding dimension",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--time-embed-dim",
|
||||||
|
type=int,
|
||||||
|
default=192,
|
||||||
|
help="Embedding dimension of timestamps embedding.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-embed-dim",
|
||||||
|
type=int,
|
||||||
|
default=192,
|
||||||
|
help="Embedding dimension of text embedding.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"sampling_rate": 24000,
|
||||||
|
"frame_shift_ms": 256 / 24000 * 1000,
|
||||||
|
"feat_dim": 100,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_vocoder():
|
||||||
|
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||||
|
return vocoder
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sentence(
|
||||||
|
save_path: str,
|
||||||
|
prompt_text: str,
|
||||||
|
prompt_wav: str,
|
||||||
|
text: str,
|
||||||
|
model: nn.Module,
|
||||||
|
vocoder: nn.Module,
|
||||||
|
tokenizer: TokenizerEmilia,
|
||||||
|
feature_extractor: TorchAudioFbank,
|
||||||
|
device: torch.device,
|
||||||
|
num_step: int = 16,
|
||||||
|
guidance_scale: float = 1.0,
|
||||||
|
speed: float = 1.0,
|
||||||
|
t_shift: float = 0.5,
|
||||||
|
target_rms: float = 0.1,
|
||||||
|
feat_scale: float = 0.1,
|
||||||
|
sampling_rate: int = 24000,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate waveform of a text based on a given prompt
|
||||||
|
waveform and its transcription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_path (str): Path to save the generated wav.
|
||||||
|
prompt_text (str): Transcription of the prompt wav.
|
||||||
|
prompt_wav (str): Path to the prompt wav file.
|
||||||
|
text (str): Text to be synthesized into a waveform.
|
||||||
|
model (nn.Module): The model used for generation.
|
||||||
|
vocoder (nn.Module): The vocoder used to convert features to waveforms.
|
||||||
|
tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
|
||||||
|
feature_extractor (TorchAudioFbank): The feature extractor used to
|
||||||
|
extract acoustic features.
|
||||||
|
device (torch.device): The device on which computations are performed.
|
||||||
|
num_step (int, optional): Number of steps for decoding. Defaults to 16.
|
||||||
|
guidance_scale (float, optional): Scale for classifier-free guidance.
|
||||||
|
Defaults to 1.0.
|
||||||
|
speed (float, optional): Speed control. Defaults to 1.0.
|
||||||
|
t_shift (float, optional): Time shift. Defaults to 0.5.
|
||||||
|
target_rms (float, optional): Target RMS for waveform normalization.
|
||||||
|
Defaults to 0.1.
|
||||||
|
feat_scale (float, optional): Scale for features.
|
||||||
|
Defaults to 0.1.
|
||||||
|
sampling_rate (int, optional): Sampling rate for the waveform.
|
||||||
|
Defaults to 24000.
|
||||||
|
Returns:
|
||||||
|
metrics (dict): Dictionary containing time and real-time
|
||||||
|
factor metrics for processing.
|
||||||
|
"""
|
||||||
|
# Convert text to tokens
|
||||||
|
tokens = tokenizer.texts_to_token_ids([text])
|
||||||
|
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
|
||||||
|
|
||||||
|
# Load and preprocess prompt wav
|
||||||
|
prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
|
||||||
|
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
|
||||||
|
if prompt_rms < target_rms:
|
||||||
|
prompt_wav = prompt_wav * target_rms / prompt_rms
|
||||||
|
|
||||||
|
if prompt_sampling_rate != sampling_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(
|
||||||
|
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
|
||||||
|
)
|
||||||
|
prompt_wav = resampler(prompt_wav)
|
||||||
|
|
||||||
|
# Extract features from prompt wav
|
||||||
|
prompt_features = feature_extractor.extract(
|
||||||
|
prompt_wav, sampling_rate=sampling_rate
|
||||||
|
).to(device)
|
||||||
|
prompt_features = prompt_features.unsqueeze(0) * feat_scale
|
||||||
|
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
|
||||||
|
|
||||||
|
# Start timing
|
||||||
|
start_t = dt.datetime.now()
|
||||||
|
|
||||||
|
# Generate features
|
||||||
|
(
|
||||||
|
pred_features,
|
||||||
|
pred_features_lens,
|
||||||
|
pred_prompt_features,
|
||||||
|
pred_prompt_features_lens,
|
||||||
|
) = model.sample(
|
||||||
|
tokens=tokens,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
prompt_features=prompt_features,
|
||||||
|
prompt_features_lens=prompt_features_lens,
|
||||||
|
speed=speed,
|
||||||
|
t_shift=t_shift,
|
||||||
|
duration="predict",
|
||||||
|
num_step=num_step,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Postprocess predicted features
|
||||||
|
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
|
||||||
|
|
||||||
|
# Start vocoder processing
|
||||||
|
start_vocoder_t = dt.datetime.now()
|
||||||
|
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
|
||||||
|
|
||||||
|
# Calculate processing times and real-time factors
|
||||||
|
t = (dt.datetime.now() - start_t).total_seconds()
|
||||||
|
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
|
||||||
|
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
|
||||||
|
wav_seconds = wav.shape[-1] / sampling_rate
|
||||||
|
rtf = t / wav_seconds
|
||||||
|
rtf_no_vocoder = t_no_vocoder / wav_seconds
|
||||||
|
rtf_vocoder = t_vocoder / wav_seconds
|
||||||
|
metrics = {
|
||||||
|
"t": t,
|
||||||
|
"t_no_vocoder": t_no_vocoder,
|
||||||
|
"t_vocoder": t_vocoder,
|
||||||
|
"wav_seconds": wav_seconds,
|
||||||
|
"rtf": rtf,
|
||||||
|
"rtf_no_vocoder": rtf_no_vocoder,
|
||||||
|
"rtf_vocoder": rtf_vocoder,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Adjust wav volume if necessary
|
||||||
|
if prompt_rms < target_rms:
|
||||||
|
wav = wav * prompt_rms / target_rms
|
||||||
|
wav = wav[0].cpu().numpy()
|
||||||
|
sf.write(save_path, wav, sampling_rate)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
res_dir: str,
|
||||||
|
test_list: str,
|
||||||
|
model: nn.Module,
|
||||||
|
vocoder: nn.Module,
|
||||||
|
tokenizer: TokenizerEmilia,
|
||||||
|
feature_extractor: TorchAudioFbank,
|
||||||
|
device: torch.device,
|
||||||
|
num_step: int = 16,
|
||||||
|
guidance_scale: float = 1.0,
|
||||||
|
speed: float = 1.0,
|
||||||
|
t_shift: float = 0.5,
|
||||||
|
target_rms: float = 0.1,
|
||||||
|
feat_scale: float = 0.1,
|
||||||
|
sampling_rate: int = 24000,
|
||||||
|
):
|
||||||
|
total_t = []
|
||||||
|
total_t_no_vocoder = []
|
||||||
|
total_t_vocoder = []
|
||||||
|
total_wav_seconds = []
|
||||||
|
|
||||||
|
with open(test_list, "r") as fr:
|
||||||
|
lines = fr.readlines()
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
|
||||||
|
save_path = f"{res_dir}/{wav_name}.wav"
|
||||||
|
metrics = generate_sentence(
|
||||||
|
save_path=save_path,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
prompt_wav=prompt_wav,
|
||||||
|
text=text,
|
||||||
|
model=model,
|
||||||
|
vocoder=vocoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
device=device,
|
||||||
|
num_step=num_step,
|
||||||
|
guidance_scale=guidance_scale,
|
||||||
|
speed=speed,
|
||||||
|
t_shift=t_shift,
|
||||||
|
target_rms=target_rms,
|
||||||
|
feat_scale=feat_scale,
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
)
|
||||||
|
print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
|
||||||
|
total_t.append(metrics["t"])
|
||||||
|
total_t_no_vocoder.append(metrics["t_no_vocoder"])
|
||||||
|
total_t_vocoder.append(metrics["t_vocoder"])
|
||||||
|
total_wav_seconds.append(metrics["wav_seconds"])
|
||||||
|
|
||||||
|
print(f"Average RTF: {np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
|
||||||
|
print(
|
||||||
|
f"Average RTF w/o vocoder: "
|
||||||
|
f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Average RTF vocoder: "
|
||||||
|
f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
model_defaults = {
|
||||||
|
"zipvoice": {
|
||||||
|
"num_step": 16,
|
||||||
|
"guidance_scale": 1.0,
|
||||||
|
},
|
||||||
|
"zipvoice_distill": {
|
||||||
|
"num_step": 8,
|
||||||
|
"guidance_scale": 3.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
model_specific_defaults = model_defaults.get(params.model_name, {})
|
||||||
|
|
||||||
|
for param, value in model_specific_defaults.items():
|
||||||
|
if getattr(params, param) == parser.get_default(param):
|
||||||
|
setattr(params, param, value)
|
||||||
|
print(f"Setting {param} to default value: {value}")
|
||||||
|
|
||||||
|
assert (params.test_list is not None) ^ (
|
||||||
|
params.prompt_wav is not None and params.prompt_text is not None and params.text is not None
|
||||||
|
), (
|
||||||
|
"For inference, please provide prompts and text with either '--test-list'"
|
||||||
|
" or '--prompt-wav, --prompt-text and --text'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
params.device = torch.device("cuda", 0)
|
||||||
|
else:
|
||||||
|
params.device = torch.device("cpu")
|
||||||
|
|
||||||
|
token_file = hf_hub_download("zhu-han/ZipVoice", filename="tokens_emilia.txt")
|
||||||
|
|
||||||
|
tokenizer = TokenizerEmilia(token_file)
|
||||||
|
|
||||||
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
params.pad_id = tokenizer.pad_id
|
||||||
|
fix_random_seed(params.seed)
|
||||||
|
|
||||||
|
if params.model_name == "zipvoice_distill":
|
||||||
|
model = get_distill_model(params)
|
||||||
|
model_ckpt = hf_hub_download(
|
||||||
|
"zhu-han/ZipVoice", filename="exp_zipvoice_distill/model.safetensors"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = get_model(params)
|
||||||
|
model_ckpt = hf_hub_download(
|
||||||
|
"zhu-han/ZipVoice", filename="exp_zipvoice/model.safetensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
safetensors.torch.load_model(model, model_ckpt)
|
||||||
|
|
||||||
|
model = model.to(params.device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
vocoder = get_vocoder()
|
||||||
|
vocoder = vocoder.to(params.device)
|
||||||
|
vocoder.eval()
|
||||||
|
|
||||||
|
config = TorchAudioFbankConfig(
|
||||||
|
sampling_rate=params.sampling_rate,
|
||||||
|
n_mels=100,
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
)
|
||||||
|
feature_extractor = TorchAudioFbank(config)
|
||||||
|
|
||||||
|
if params.test_list:
|
||||||
|
os.makedirs(params.res_dir, exist_ok=True)
|
||||||
|
generate(
|
||||||
|
res_dir=params.res_dir,
|
||||||
|
test_list=params.test_list,
|
||||||
|
model=model,
|
||||||
|
vocoder=vocoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
device=params.device,
|
||||||
|
num_step=params.num_step,
|
||||||
|
guidance_scale=params.guidance_scale,
|
||||||
|
speed=params.speed,
|
||||||
|
t_shift=params.t_shift,
|
||||||
|
target_rms=params.target_rms,
|
||||||
|
feat_scale=params.feat_scale,
|
||||||
|
sampling_rate=params.sampling_rate,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
generate_sentence(
|
||||||
|
save_path=params.res_wav_path,
|
||||||
|
prompt_text=params.prompt_text,
|
||||||
|
prompt_wav=params.prompt_wav,
|
||||||
|
text=params.text,
|
||||||
|
model=model,
|
||||||
|
vocoder=vocoder,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
device=params.device,
|
||||||
|
num_step=params.num_step,
|
||||||
|
guidance_scale=params.guidance_scale,
|
||||||
|
speed=params.speed,
|
||||||
|
t_shift=params.t_shift,
|
||||||
|
target_rms=params.target_rms,
|
||||||
|
feat_scale=params.feat_scale,
|
||||||
|
sampling_rate=params.sampling_rate,
|
||||||
|
)
|
||||||
|
print("Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()