Mirror of https://github.com/k2-fsa/icefall.git
rm zipvoice
parent 343b8fa2dc
commit 0c9bd934c2
@@ -1,412 +0,0 @@
## ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching

[Paper](http://arxiv.org/abs/2506.13053)
[Demo page](https://zipvoice.github.io/)

## Overview

ZipVoice is a high-quality zero-shot TTS model with a small model size and fast inference speed.

#### Key features:

- Small and fast: only 123M parameters.

- High-quality: state-of-the-art voice cloning performance in speaker similarity, intelligibility, and naturalness.

- Multi-lingual: supports Chinese and English.

## News

**2025/06/16**: 🔥 ZipVoice is released.
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
* Clone the icefall repository and change to the zipvoice directory:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/k2-fsa/icefall.git
|
||||
cd icefall/egs/zipvoice
|
||||
```
|
||||
|
||||
* Create a Python virtual environment (optional but recommended):
|
||||
|
||||
```bash
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
```
|
||||
|
||||
* Install the required packages:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
To generate speech with our pre-trained ZipVoice or ZipVoice-Distill models, use the following commands (the required models will be downloaded from HuggingFace):
|
||||
|
||||
### 1. Inference of a single sentence:
|
||||
```bash
|
||||
python3 zipvoice/zipvoice_infer.py \
|
||||
--model-name "zipvoice_distill" \
|
||||
--prompt-wav prompt.wav \
|
||||
--prompt-text "I am the transcription of the prompt wav." \
|
||||
--text "I am the text to be synthesized." \
|
||||
--res-wav-path result.wav
|
||||
|
||||
# Example with a pre-defined prompt wav and text
|
||||
python3 zipvoice/zipvoice_infer.py \
|
||||
--model-name "zipvoice_distill" \
|
||||
--prompt-wav assets/prompt-en.wav \
|
||||
--prompt-text "Some call me nature, others call me mother nature. I've been here for over four point five billion years, twenty two thousand five hundred times longer than you." \
|
||||
--text "Welcome to use our tts model, have fun!" \
|
||||
--res-wav-path result.wav
|
||||
```
|
||||
|
||||
### 2. Inference of a list of sentences:
|
||||
```bash
|
||||
python3 zipvoice/zipvoice_infer.py \
|
||||
--model-name "zipvoice_distill" \
|
||||
--test-list test.tsv \
|
||||
--res-dir results/test
|
||||
```
|
||||
|
||||
- `--model-name` can be `zipvoice` or `zipvoice_distill`, which are models before and after distillation, respectively.
|
||||
- Each line of `test.tsv` is in the format of `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}` (see the example below).
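
For reference, here is a minimal sketch of how one line of `test.tsv` could be written. The wav name `utt_0001` is a placeholder; the prompt wav and texts are taken from the single-sentence example above:

```bash
# Write one tab-separated line: wav_name, prompt transcription, prompt wav, target text.
printf '%s\t%s\t%s\t%s\n' \
  "utt_0001" \
  "Some call me nature, others call me mother nature. I've been here for over four point five billion years, twenty two thousand five hundred times longer than you." \
  "assets/prompt-en.wav" \
  "Welcome to use our tts model, have fun!" > test.tsv
```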
|
||||
|
||||
|
||||
> **Note:** If you have trouble connecting to HuggingFace, try:
|
||||
```bash
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
```
|
||||
|
||||
## Training Your Own Model
|
||||
|
||||
The following steps show how to train a model from scratch on the Emilia and LibriTTS datasets, respectively.
|
||||
|
||||
### 0. Install dependencies for training
|
||||
|
||||
```bash
|
||||
# Install pytorch and k2.
|
||||
# If you want to use different versions, please refer to https://k2-fsa.org/get-started/k2/ for details.
|
||||
# For users in China mainland, please refer to https://k2-fsa.org/zh-CN/get-started/k2/
|
||||
|
||||
# Note: Make sure you have installed the correct version of PyTorch and k2 that matches your CUDA version.
|
||||
# For example, if you want to use PyTorch 2.5.1 with CUDA 12.1, you can install PyTorch and k2 as follows:
|
||||
|
||||
pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
|
||||
pip install k2==1.24.4.dev20250208+cuda12.1.torch2.5.1 -f https://k2-fsa.github.io/k2/cuda.html
|
||||
|
||||
pip install -r ../../requirements.txt
|
||||
```
|
||||
|
||||
### 1. Data Preparation
|
||||
|
||||
#### 1.1. Prepare the Emilia dataset
|
||||
|
||||
```bash
|
||||
bash scripts/prepare_emilia.sh
|
||||
```
|
||||
|
||||
See [scripts/prepare_emilia.sh](scripts/prepare_emilia.sh) for step by step instructions.
|
||||
|
||||
#### 1.2 Prepare the LibriTTS dataset
|
||||
|
||||
```bash
|
||||
bash scripts/prepare_libritts.sh
|
||||
```
|
||||
|
||||
See [scripts/prepare_libritts.sh](scripts/prepare_libritts.sh) for step by step instructions.
|
||||
|
||||
### 2. Training
|
||||
|
||||
#### 2.1 Training on Emilia
|
||||
|
||||
<details>
|
||||
<summary>Expand to view training steps</summary>
|
||||
|
||||
##### 2.1.1 Train the ZipVoice model
|
||||
|
||||
- Training:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/train_flow.py \
|
||||
--world-size 8 \
|
||||
--use-fp16 1 \
|
||||
--dataset emilia \
|
||||
--max-duration 500 \
|
||||
--lr-hours 30000 \
|
||||
--lr-batches 7500 \
|
||||
--token-file "data/tokens_emilia.txt" \
|
||||
--manifest-dir "data/fbank" \
|
||||
--num-epochs 11 \
|
||||
--exp-dir zipvoice/exp_zipvoice
|
||||
```
|
||||
|
||||
- Average the checkpoints to produce the final model:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/generate_averaged_model.py \
|
||||
--epoch 11 \
|
||||
--avg 4 \
|
||||
--distill 0 \
|
||||
--token-file data/tokens_emilia.txt \
|
||||
--dataset "emilia" \
|
||||
--exp-dir ./zipvoice/exp_zipvoice
|
||||
# The generated model is zipvoice/exp_zipvoice/epoch-11-avg-4.pt
|
||||
```
|
||||
|
||||
##### 2.1.2 Train the ZipVoice-Distill model (Optional)
|
||||
|
||||
- The first-stage distillation:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/train_distill.py \
|
||||
--world-size 8 \
|
||||
--use-fp16 1 \
|
||||
--tensorboard 1 \
|
||||
--dataset "emilia" \
|
||||
--base-lr 0.0005 \
|
||||
--max-duration 500 \
|
||||
--token-file "data/tokens_emilia.txt" \
|
||||
--manifest-dir "data/fbank" \
|
||||
--teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
|
||||
--num-updates 60000 \
|
||||
--distill-stage "first" \
|
||||
--exp-dir zipvoice/exp_zipvoice_distill_1stage
|
||||
```
|
||||
|
||||
- Average checkpoints for the second-stage initialization:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/generate_averaged_model.py \
|
||||
--iter 60000 \
|
||||
--avg 7 \
|
||||
--distill 1 \
|
||||
--token-file data/tokens_emilia.txt \
|
||||
--dataset "emilia" \
|
||||
--exp-dir ./zipvoice/exp_zipvoice_distill_1stage
|
||||
# The generated model is zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt
|
||||
```
|
||||
|
||||
- The second-stage distillation:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/train_distill.py \
|
||||
--world-size 8 \
|
||||
--use-fp16 1 \
|
||||
--tensorboard 1 \
|
||||
--dataset "emilia" \
|
||||
--base-lr 0.0001 \
|
||||
--max-duration 200 \
|
||||
--token-file "data/tokens_emilia.txt" \
|
||||
--manifest-dir "data/fbank" \
|
||||
--teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \
|
||||
--num-updates 2000 \
|
||||
--distill-stage "second" \
|
||||
--exp-dir zipvoice/exp_zipvoice_distill_new
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
#### 2.2 Training on LibriTTS
|
||||
|
||||
<details>
|
||||
<summary>Expand to view training steps</summary>
|
||||
|
||||
##### 2.2.1 Train the ZipVoice model
|
||||
|
||||
- Training:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/train_flow.py \
|
||||
--world-size 8 \
|
||||
--use-fp16 1 \
|
||||
--dataset libritts \
|
||||
--max-duration 250 \
|
||||
--lr-epochs 10 \
|
||||
--lr-batches 7500 \
|
||||
--token-file "data/tokens_libritts.txt" \
|
||||
--manifest-dir "data/fbank" \
|
||||
--num-epochs 60 \
|
||||
--exp-dir zipvoice/exp_zipvoice_libritts
|
||||
```
|
||||
|
||||
- Average the checkpoints to produce the final model:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/generate_averaged_model.py \
|
||||
--epoch 60 \
|
||||
--avg 10 \
|
||||
--distill 0 \
|
||||
--token-file data/tokens_libritts.txt \
|
||||
--dataset "libritts" \
|
||||
--exp-dir ./zipvoice/exp_zipvoice_libritts
|
||||
# The generated model is zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt
|
||||
```
|
||||
|
||||
##### 2.2.2 Train the ZipVoice-Distill model (Optional)
|
||||
|
||||
- The first-stage distillation:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/train_distill.py \
|
||||
--world-size 8 \
|
||||
--use-fp16 1 \
|
||||
--tensorboard 1 \
|
||||
--dataset "libritts" \
|
||||
--base-lr 0.001 \
|
||||
--max-duration 250 \
|
||||
--token-file "data/tokens_libritts.txt" \
|
||||
--manifest-dir "data/fbank" \
|
||||
--teacher-model zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
|
||||
--num-epochs 6 \
|
||||
--distill-stage "first" \
|
||||
--exp-dir zipvoice/exp_zipvoice_distill_1stage_libritts
|
||||
```
|
||||
|
||||
- Average checkpoints for the second-stage initialization:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 ./zipvoice/generate_averaged_model.py \
|
||||
--epoch 6 \
|
||||
--avg 3 \
|
||||
--distill 1 \
|
||||
--token-file data/tokens_libritts.txt \
|
||||
--dataset "libritts" \
|
||||
--exp-dir ./zipvoice/exp_zipvoice_distill_1stage_libritts
|
||||
# The generated model is zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt
|
||||
```
|
||||
|
||||
- The second-stage distillation:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/train_distill.py \
|
||||
--world-size 8 \
|
||||
--use-fp16 1 \
|
||||
--tensorboard 1 \
|
||||
--dataset "libritts" \
|
||||
--base-lr 0.001 \
|
||||
--max-duration 250 \
|
||||
--token-file "data/tokens_libritts.txt" \
|
||||
--manifest-dir "data/fbank" \
|
||||
--teacher-model zipvoice/exp_zipvoice_distill_1stage_libritts/epoch-6-avg-3.pt \
|
||||
--num-epochs 6 \
|
||||
--distill-stage "second" \
|
||||
--exp-dir zipvoice/exp_zipvoice_distill_libritts
|
||||
```
|
||||
|
||||
- Average checkpoints to produce the final model:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 ./zipvoice/generate_averaged_model.py \
|
||||
--epoch 6 \
|
||||
--avg 3 \
|
||||
--distill 1 \
|
||||
--token-file data/tokens_libritts.txt \
|
||||
--dataset "libritts" \
|
||||
--exp-dir ./zipvoice/exp_zipvoice_distill_libritts
|
||||
# The generated model is ./zipvoice/exp_zipvoice_distill_libritts/epoch-6-avg-3.pt
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
### 3. Inference with the trained model
|
||||
|
||||
#### 3.1 Inference with the model trained on Emilia
|
||||
<details>
|
||||
<summary>Expand to view inference commands.</summary>
|
||||
|
||||
##### 3.1.1 ZipVoice model (before distillation):
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/infer.py \
|
||||
--checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
|
||||
--distill 0 \
|
||||
--token-file "data/tokens_emilia.txt" \
|
||||
--test-list test.tsv \
|
||||
--res-dir results/test \
|
||||
--num-step 16 \
|
||||
--guidance-scale 1
|
||||
```
|
||||
|
||||
##### 3.1.2 ZipVoice-Distill model:
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/infer.py \
|
||||
--checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
|
||||
--distill 1 \
|
||||
--token-file "data/tokens_emilia.txt" \
|
||||
--test-list test.tsv \
|
||||
--res-dir results/test_distill \
|
||||
--num-step 8 \
|
||||
--guidance-scale 3
|
||||
```
|
||||
</details>
|
||||
|
||||
|
||||
#### 3.2 Inference with the model trained on LibriTTS
|
||||
|
||||
<details>
|
||||
<summary>Expand to view inference commands.</summary>
|
||||
|
||||
##### 3.2.1 ZipVoice model (before distillation):
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/infer.py \
|
||||
--checkpoint zipvoice/exp_zipvoice_libritts/epoch-60-avg-10.pt \
|
||||
--distill 0 \
|
||||
--token-file "data/tokens_libritts.txt" \
|
||||
--test-list test.tsv \
|
||||
--res-dir results/test_libritts \
|
||||
--num-step 8 \
|
||||
--guidance-scale 1 \
|
||||
--target-rms 1.0 \
|
||||
--t-shift 0.7
|
||||
```
|
||||
|
||||
##### 3.2.2 ZipVoice-Distill model
|
||||
|
||||
```bash
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
python3 zipvoice/infer.py \
|
||||
--checkpoint zipvoice/exp_zipvoice_distill/epoch-6-avg-3.pt \
|
||||
--distill 1 \
|
||||
--token-file "data/tokens_libritts.txt" \
|
||||
--test-list test.tsv \
|
||||
--res-dir results/test_distill_libritts \
|
||||
--num-step 4 \
|
||||
--guidance-scale 3 \
|
||||
--target-rms 1.0 \
|
||||
--t-shift 0.7
|
||||
```
|
||||
</details>
|
||||
|
||||
### 4. Evaluation on benchmarks
|
||||
|
||||
See [local/evaluate.sh](local/evaluate.sh) for details of the objective-metric evaluation on three test sets: LibriSpeech-PC test-clean, Seed-TTS test-en, and Seed-TTS test-zh.
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{zhu-2025-zipvoice,
|
||||
title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching},
|
||||
author={Han Zhu and Wei Kang and Zengwei Yao and Liyong Guo and Fangjun Kuang and Zhaoqing Li and Weiji Zhuang and Long Lin and Daniel Povey},
|
||||
journal={arXiv preprint arXiv:2506.13053},
|
||||
year={2025},
|
||||
}
|
||||
```
|
Binary file not shown.
@@ -1,287 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor as Pool
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import lhotse
|
||||
import torch
|
||||
from feature import TorchAudioFbank, TorchAudioFbankConfig
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
LilcomChunkyWriter,
|
||||
load_manifest_lazy,
|
||||
set_audio_duration_mismatch_tolerance,
|
||||
)
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
"""Used in argparse.ArgumentParser.add_argument to indicate
|
||||
that a type is a bool type and the user can enter
|
||||
|
||||
- yes, true, t, y, 1, to represent True
|
||||
- no, false, f, n, 0, to represent False
|
||||
|
||||
See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--sampling-rate",
|
||||
type=int,
|
||||
default=24000,
|
||||
help="The target sampling rate, the audio will be resampled to this sampling_rate.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-shift",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Frame shift in samples",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--frame-length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Frame length in samples",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-mel-bins",
|
||||
type=int,
|
||||
default=100,
|
||||
help="The num of mel filters.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
help="Dataset name.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--subset",
|
||||
type=str,
|
||||
help="The subset of the dataset.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--source-dir",
|
||||
type=str,
|
||||
default="data/manifests",
|
||||
help="The source directory of manifest files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dest-dir",
|
||||
type=str,
|
||||
default="data/fbank",
|
||||
help="The destination directory of manifest files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use splited cuts.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-begin",
|
||||
type=int,
|
||||
help="Start idx of splited cuts.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-end",
|
||||
type=int,
|
||||
help="End idx of splited cuts.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-duration",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="The batch duration when computing the features.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-jobs", type=int, default=20, help="The number of extractor workers."
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def compute_fbank_split_single(params, idx):
|
||||
lhotse.set_audio_duration_mismatch_tolerance(0.1) # for emilia
|
||||
src_dir = Path(params.source_dir)
|
||||
output_dir = Path(params.dest_dir)
|
||||
num_mel_bins = params.num_mel_bins
|
||||
|
||||
if not src_dir.exists():
|
||||
logging.error(f"{src_dir} not exists")
|
||||
return
|
||||
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
num_digits = 8
|
||||
|
||||
config = TorchAudioFbankConfig(
|
||||
sampling_rate=params.sampling_rate,
|
||||
n_mels=params.num_mel_bins,
|
||||
n_fft=params.frame_length,
|
||||
hop_length=params.frame_shift,
|
||||
)
|
||||
extractor = TorchAudioFbank(config)
|
||||
|
||||
prefix = params.dataset
|
||||
subset = params.subset
|
||||
suffix = "jsonl.gz"
|
||||
|
||||
idx = f"{idx}".zfill(num_digits)
|
||||
cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}"
|
||||
|
||||
if (src_dir / cuts_filename).is_file():
|
||||
logging.info(f"Loading manifests {src_dir / cuts_filename}")
|
||||
cut_set = load_manifest_lazy(src_dir / cuts_filename)
|
||||
else:
|
||||
logging.warning(f"Raw {cuts_filename} not exists, skipping")
|
||||
return
|
||||
|
||||
cut_set = cut_set.resample(params.sampling_rate)
|
||||
|
||||
if (output_dir / cuts_filename).is_file():
|
||||
logging.info(f"{cuts_filename} already exists - skipping.")
|
||||
return
|
||||
|
||||
logging.info(f"Processing {subset}.{idx} of {prefix}")
|
||||
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}",
|
||||
num_workers=4,
|
||||
batch_duration=params.batch_duration,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
cut_set.to_file(output_dir / cuts_filename)
|
||||
|
||||
|
||||
def compute_fbank_split(params):
|
||||
if params.split_end < params.split_begin:
|
||||
logging.warning(
|
||||
f"Split begin should be smaller than split end, given "
|
||||
f"{params.split_begin} -> {params.split_end}."
|
||||
)
|
||||
|
||||
with Pool(max_workers=params.num_jobs) as pool:
|
||||
futures = [
|
||||
pool.submit(compute_fbank_split_single, params, i)
|
||||
for i in range(params.split_begin, params.split_end)
|
||||
]
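# Wait for every split to finish so that exceptions raised in worker processes are surfaced here.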
|
||||
for f in futures:
|
||||
f.result()
|
||||
f.done()
|
||||
|
||||
|
||||
def compute_fbank(params):
|
||||
src_dir = Path(params.source_dir)
|
||||
output_dir = Path(params.dest_dir)
|
||||
num_jobs = params.num_jobs
|
||||
num_mel_bins = params.num_mel_bins
|
||||
|
||||
prefix = params.dataset
|
||||
subset = params.subset
|
||||
suffix = "jsonl.gz"
|
||||
|
||||
cut_set_name = f"{prefix}_cuts_{subset}.{suffix}"
|
||||
|
||||
if (src_dir / cut_set_name).is_file():
|
||||
logging.info(f"Loading manifests {src_dir / cut_set_name}")
|
||||
cut_set = load_manifest_lazy(src_dir / cut_set_name)
|
||||
else:
|
||||
recordings = load_manifest_lazy(
|
||||
src_dir / f"{prefix}_recordings_{subset}.{suffix}"
|
||||
)
|
||||
supervisions = load_manifest_lazy(
|
||||
src_dir / f"{prefix}_supervisions_{subset}.{suffix}"
|
||||
)
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=recordings,
|
||||
supervisions=supervisions,
|
||||
)
|
||||
|
||||
cut_set = cut_set.resample(params.sampling_rate)
|
||||
|
||||
config = TorchAudioFbankConfig(
|
||||
sampling_rate=params.sampling_rate,
|
||||
n_mels=params.num_mel_bins,
|
||||
n_fft=params.frame_length,
|
||||
hop_length=params.frame_shift,
|
||||
)
|
||||
extractor = TorchAudioFbank(config)
|
||||
|
||||
cuts_filename = f"{prefix}_cuts_{subset}.{suffix}"
|
||||
if (output_dir / cuts_filename).is_file():
|
||||
logging.info(f"{prefix} {subset} already exists - skipping.")
|
||||
return
|
||||
logging.info(f"Processing {subset} of {prefix}")
|
||||
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{subset}",
|
||||
num_jobs=num_jobs,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(output_dir / cuts_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
if args.split_cuts:
|
||||
compute_fbank_split(params=args)
|
||||
else:
|
||||
compute_fbank(params=args)
|
@@ -1,508 +0,0 @@
|
||||
"""
|
||||
Calculate pairwise Speaker Similarity between two speech directories.
|
||||
SV model wavlm_large_finetune.pth is downloaded from
|
||||
https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
|
||||
SSL model wavlm_large.pt is downloaded from
|
||||
https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
|
||||
"""
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--eval-path", type=str, help="path of the evaluated speech directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-list",
|
||||
type=str,
|
||||
help="path of the file list that contains the corresponding "
|
||||
"relationship between the prompt and evaluated speech. "
|
||||
"The first column is the wav name and the third column is the prompt speech",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sv-model-path",
|
||||
type=str,
|
||||
default="model/UniSpeech/wavlm_large_finetune.pth",
|
||||
help="path of the wavlm-based ECAPA-TDNN model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-model-path",
|
||||
type=str,
|
||||
default="model/s3prl/wavlm_large.pt",
|
||||
help="path of the wavlm SSL model",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
class SpeakerSimilarity:
|
||||
def __init__(
|
||||
self,
|
||||
sv_model_path="model/UniSpeech/wavlm_large_finetune.pth",
|
||||
ssl_model_path="model/s3prl/wavlm_large.pt",
|
||||
):
|
||||
"""
|
||||
Initialize
|
||||
"""
|
||||
self.sample_rate = 16000
|
||||
self.channels = 1
|
||||
self.device = (
|
||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
)
|
||||
logging.info("[Speaker Similarity] Using device: {}".format(self.device))
|
||||
self.model = ECAPA_TDNN_WAVLLM(
|
||||
feat_dim=1024,
|
||||
channels=512,
|
||||
emb_dim=256,
|
||||
sr=16000,
|
||||
ssl_model_path=ssl_model_path,
|
||||
)
|
||||
state_dict = torch.load(
|
||||
sv_model_path, map_location=lambda storage, loc: storage
|
||||
)
|
||||
self.model.load_state_dict(state_dict["model"], strict=False)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
def get_embeddings(self, wav_list, dtype="float32"):
|
||||
"""
|
||||
Get embeddings
|
||||
"""
|
||||
|
||||
def _load_speech_task(fname, sample_rate):
|
||||
|
||||
wav_data, sr = sf.read(fname, dtype=dtype)
|
||||
if sr != sample_rate:
|
||||
wav_data = librosa.resample(
|
||||
wav_data, orig_sr=sr, target_sr=self.sample_rate
|
||||
)
|
||||
wav_data = torch.from_numpy(wav_data)
|
||||
|
||||
return wav_data
|
||||
|
||||
embd_lst = []
|
||||
for file_path in tqdm(wav_list):
|
||||
speech = _load_speech_task(file_path, self.sample_rate)
|
||||
speech = speech.to(self.device)
|
||||
with torch.no_grad():
|
||||
embd = self.model([speech])
|
||||
embd_lst.append(embd)
|
||||
|
||||
return embd_lst
|
||||
|
||||
def score(
|
||||
self,
|
||||
eval_path,
|
||||
test_list,
|
||||
dtype="float32",
|
||||
):
|
||||
"""
|
||||
Computes the Speaker Similarity (SIM-o) between two directories of speech files.
|
||||
|
||||
Parameters:
|
||||
- eval_path (str): Path to the directory containing evaluation speech files.
|
||||
- test_list (str): Path to the file containing the corresponding relationship
|
||||
between prompt and evaluated speech.
|
||||
- dtype (str, optional): Data type for loading speech. Default is "float32".
|
||||
|
||||
Returns:
|
||||
- float: The Speaker Similarity (SIM-o) score between the two directories
|
||||
of speech files.
|
||||
"""
|
||||
prompt_wavs = []
|
||||
eval_wavs = []
|
||||
with open(test_list, "r") as fr:
|
||||
lines = fr.readlines()
|
||||
for line in lines:
|
||||
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
|
||||
prompt_wavs.append(prompt_wav)
|
||||
eval_wavs.append(os.path.join(eval_path, wav_name + ".wav"))
|
||||
embds_prompt = self.get_embeddings(prompt_wavs, dtype=dtype)
|
||||
|
||||
embds_eval = self.get_embeddings(eval_wavs, dtype=dtype)
|
||||
|
||||
# Check if embeddings are empty
|
||||
if len(embds_prompt) == 0:
|
||||
logging.info("[Speaker Similarity] real set dir is empty, exiting...")
|
||||
return -1
|
||||
if len(embds_eval) == 0:
|
||||
logging.info("[Speaker Similarity] eval set dir is empty, exiting...")
|
||||
return -1
|
||||
|
||||
scores = []
|
||||
for real_embd, eval_embd in zip(embds_prompt, embds_eval):
|
||||
scores.append(
|
||||
torch.nn.functional.cosine_similarity(real_embd, eval_embd, dim=-1)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
return np.mean(scores)
|
||||
|
||||
|
||||
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
|
||||
|
||||
""" Res2Conv1d + BatchNorm1d + ReLU
|
||||
"""
|
||||
|
||||
|
||||
class Res2Conv1dReluBn(nn.Module):
|
||||
"""
|
||||
in_channels == out_channels == channels
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
scale=4,
|
||||
):
|
||||
super().__init__()
|
||||
assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
|
||||
self.scale = scale
|
||||
self.width = channels // scale
|
||||
self.nums = scale if scale == 1 else scale - 1
|
||||
|
||||
self.convs = []
|
||||
self.bns = []
|
||||
for i in range(self.nums):
|
||||
self.convs.append(
|
||||
nn.Conv1d(
|
||||
self.width,
|
||||
self.width,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
bias=bias,
|
||||
)
|
||||
)
|
||||
self.bns.append(nn.BatchNorm1d(self.width))
|
||||
self.convs = nn.ModuleList(self.convs)
|
||||
self.bns = nn.ModuleList(self.bns)
|
||||
|
||||
def forward(self, x):
|
||||
out = []
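# Res2Net-style hierarchical processing: the channels are split into `scale` groups;
# each group (after the first) is added to the previous group's output before its conv.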
|
||||
spx = torch.split(x, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
# Order: conv -> relu -> bn
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.bns[i](F.relu(sp))
|
||||
out.append(sp)
|
||||
if self.scale != 1:
|
||||
out.append(spx[self.nums])
|
||||
out = torch.cat(out, dim=1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
""" Conv1d + BatchNorm1d + ReLU
|
||||
"""
|
||||
|
||||
|
||||
class Conv1dReluBn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias
|
||||
)
|
||||
self.bn = nn.BatchNorm1d(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bn(F.relu(self.conv(x)))
|
||||
|
||||
|
||||
""" The SE connection of 1D case.
|
||||
"""
|
||||
|
||||
|
||||
class SE_Connect(nn.Module):
|
||||
def __init__(self, channels, se_bottleneck_dim=128):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(channels, se_bottleneck_dim)
|
||||
self.linear2 = nn.Linear(se_bottleneck_dim, channels)
|
||||
|
||||
def forward(self, x):
|
||||
out = x.mean(dim=2)
|
||||
out = F.relu(self.linear1(out))
|
||||
out = torch.sigmoid(self.linear2(out))
|
||||
out = x * out.unsqueeze(2)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
""" SE-Res2Block of the ECAPA-TDNN architecture.
|
||||
"""
|
||||
|
||||
|
||||
# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
|
||||
# return nn.Sequential(
|
||||
# Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
|
||||
# Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
|
||||
# Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
|
||||
# SE_Connect(channels)
|
||||
# )
|
||||
|
||||
|
||||
class SE_Res2Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
scale,
|
||||
se_bottleneck_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.Conv1dReluBn1 = Conv1dReluBn(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.Res2Conv1dReluBn = Res2Conv1dReluBn(
|
||||
out_channels, kernel_size, stride, padding, dilation, scale=scale
|
||||
)
|
||||
self.Conv1dReluBn2 = Conv1dReluBn(
|
||||
out_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
|
||||
|
||||
self.shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
if self.shortcut:
|
||||
residual = self.shortcut(x)
|
||||
|
||||
x = self.Conv1dReluBn1(x)
|
||||
x = self.Res2Conv1dReluBn(x)
|
||||
x = self.Conv1dReluBn2(x)
|
||||
x = self.SE_Connect(x)
|
||||
|
||||
return x + residual
|
||||
|
||||
|
||||
""" Attentive weighted mean and standard deviation pooling.
|
||||
"""
|
||||
|
||||
|
||||
class AttentiveStatsPool(nn.Module):
|
||||
def __init__(self, in_dim, attention_channels=128, global_context_att=False):
|
||||
super().__init__()
|
||||
self.global_context_att = global_context_att
|
||||
|
||||
# Use Conv1d with stride == 1 rather than Linear,
|
||||
# then we don't need to transpose inputs.
|
||||
if global_context_att:
|
||||
self.linear1 = nn.Conv1d(
|
||||
in_dim * 3, attention_channels, kernel_size=1
|
||||
) # equals W and b in the paper
|
||||
else:
|
||||
self.linear1 = nn.Conv1d(
|
||||
in_dim, attention_channels, kernel_size=1
|
||||
) # equals W and b in the paper
|
||||
self.linear2 = nn.Conv1d(
|
||||
attention_channels, in_dim, kernel_size=1
|
||||
) # equals V and k in the paper
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.global_context_att:
|
||||
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
||||
context_std = torch.sqrt(
|
||||
torch.var(x, dim=-1, keepdim=True) + 1e-10
|
||||
).expand_as(x)
|
||||
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
||||
else:
|
||||
x_in = x
|
||||
|
||||
# DON'T use ReLU here! In experiments, I find ReLU hard to converge.
|
||||
alpha = torch.tanh(self.linear1(x_in))
|
||||
# alpha = F.relu(self.linear1(x_in))
|
||||
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
||||
mean = torch.sum(alpha * x, dim=2)
|
||||
residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
|
||||
std = torch.sqrt(residuals.clamp(min=1e-9))
|
||||
return torch.cat([mean, std], dim=1)
|
||||
|
||||
|
||||
class ECAPA_TDNN_WAVLLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
feat_dim=80,
|
||||
channels=512,
|
||||
emb_dim=192,
|
||||
global_context_att=False,
|
||||
sr=16000,
|
||||
ssl_model_path=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.sr = sr
|
||||
|
||||
if ssl_model_path is None:
|
||||
self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
|
||||
else:
|
||||
self.feature_extract = torch.hub.load(
|
||||
os.path.dirname(ssl_model_path),
|
||||
"wavlm_local",
|
||||
source="local",
|
||||
ckpt=ssl_model_path,
|
||||
)
|
||||
|
||||
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
||||
self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
|
||||
):
|
||||
self.feature_extract.model.encoder.layers[
|
||||
23
|
||||
].self_attn.fp32_attention = False
|
||||
if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
|
||||
self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
|
||||
):
|
||||
self.feature_extract.model.encoder.layers[
|
||||
11
|
||||
].self_attn.fp32_attention = False
|
||||
|
||||
self.feat_num = self.get_feat_num()
|
||||
self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
|
||||
|
||||
self.instance_norm = nn.InstanceNorm1d(feat_dim)
|
||||
# self.channels = [channels] * 4 + [channels * 3]
|
||||
self.channels = [channels] * 4 + [1536]
|
||||
|
||||
self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
|
||||
self.layer2 = SE_Res2Block(
|
||||
self.channels[0],
|
||||
self.channels[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=2,
|
||||
dilation=2,
|
||||
scale=8,
|
||||
se_bottleneck_dim=128,
|
||||
)
|
||||
self.layer3 = SE_Res2Block(
|
||||
self.channels[1],
|
||||
self.channels[2],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=3,
|
||||
dilation=3,
|
||||
scale=8,
|
||||
se_bottleneck_dim=128,
|
||||
)
|
||||
self.layer4 = SE_Res2Block(
|
||||
self.channels[2],
|
||||
self.channels[3],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=4,
|
||||
dilation=4,
|
||||
scale=8,
|
||||
se_bottleneck_dim=128,
|
||||
)
|
||||
|
||||
# self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
|
||||
cat_channels = channels * 3
|
||||
self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
|
||||
self.pooling = AttentiveStatsPool(
|
||||
self.channels[-1],
|
||||
attention_channels=128,
|
||||
global_context_att=global_context_att,
|
||||
)
|
||||
self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
|
||||
self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
|
||||
|
||||
def get_feat_num(self):
|
||||
self.feature_extract.eval()
|
||||
wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
|
||||
with torch.no_grad():
|
||||
features = self.feature_extract(wav)
|
||||
select_feature = features["hidden_states"]
|
||||
if isinstance(select_feature, (list, tuple)):
|
||||
return len(select_feature)
|
||||
else:
|
||||
return 1
|
||||
|
||||
def get_feat(self, x):
|
||||
with torch.no_grad():
|
||||
x = self.feature_extract([sample for sample in x])
|
||||
|
||||
x = x["hidden_states"]
|
||||
if isinstance(x, (list, tuple)):
|
||||
x = torch.stack(x, dim=0)
|
||||
else:
|
||||
x = x.unsqueeze(0)
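# Combine all SSL hidden layers with a learnable, softmax-normalized weight per layer.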
|
||||
norm_weights = (
|
||||
F.softmax(self.feature_weight, dim=-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
x = (norm_weights * x).sum(dim=0)
|
||||
x = torch.transpose(x, 1, 2) + 1e-6
|
||||
|
||||
x = self.instance_norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.get_feat(x)
|
||||
|
||||
out1 = self.layer1(x)
|
||||
out2 = self.layer2(out1)
|
||||
out3 = self.layer3(out2)
|
||||
out4 = self.layer4(out3)
|
||||
|
||||
out = torch.cat([out2, out3, out4], dim=1)
|
||||
out = F.relu(self.conv(out))
|
||||
out = self.bn(self.pooling(out))
|
||||
out = self.linear(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
SIM = SpeakerSimilarity(
|
||||
sv_model_path=args.sv_model_path, ssl_model_path=args.ssl_model_path
|
||||
)
|
||||
score = SIM.score(args.eval_path, args.test_list)
|
||||
logging.info(f"SIM-o score: {score:.3f}")
|
@@ -1,294 +0,0 @@
|
||||
"""
|
||||
Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
|
||||
adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
|
||||
|
||||
# Download model checkpoints
|
||||
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
|
||||
wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import fairseq
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--wav-path", type=str, help="path of the evaluated speech directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--utmos-model-path",
|
||||
type=str,
|
||||
default="model/huggingface/utmos/utmos.pt",
|
||||
help="path of the UTMOS model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-model-path",
|
||||
type=str,
|
||||
default="model/huggingface/utmos/wav2vec_small.pt",
|
||||
help="path of the wav2vec SSL model",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
class UTMOSScore:
|
||||
"""Predicting score for each audio clip."""
|
||||
|
||||
def __init__(self, utmos_model_path, ssl_model_path):
|
||||
self.sample_rate = 16000
|
||||
self.device = (
|
||||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
)
|
||||
self.model = (
|
||||
BaselineLightningModule.load_from_checkpoint(
|
||||
utmos_model_path, ssl_model_path=ssl_model_path
|
||||
)
|
||||
.eval()
|
||||
.to(self.device)
|
||||
)
|
||||
|
||||
def score(self, wavs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
wavs: waveforms to be evaluated. A tensor with 1 or 2 dimensions
is treated as a single audio clip; a 3-dimensional tensor
is processed as a batch.
|
||||
"""
|
||||
if len(wavs.shape) == 1:
|
||||
out_wavs = wavs.unsqueeze(0).unsqueeze(0)
|
||||
elif len(wavs.shape) == 2:
|
||||
out_wavs = wavs.unsqueeze(0)
|
||||
elif len(wavs.shape) == 3:
|
||||
out_wavs = wavs
|
||||
else:
|
||||
raise ValueError("Dimension of input tensor needs to be <= 3.")
|
||||
bs = out_wavs.shape[0]
|
||||
batch = {
|
||||
"wav": out_wavs,
|
||||
"domains": torch.zeros(bs, dtype=torch.int).to(self.device),
|
||||
"judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
|
||||
}
|
||||
with torch.no_grad():
|
||||
output = self.model(batch)
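# Map the normalized prediction back to the 1-5 MOS scale; UTMOS targets are
# commonly normalized as (MOS - 3) / 2 during training, hence the inverse mapping below.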
|
||||
|
||||
return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
|
||||
|
||||
def score_dir(self, dir, dtype="float32"):
|
||||
def _load_speech_task(fname, sample_rate):
|
||||
|
||||
wav_data, sr = sf.read(fname, dtype=dtype)
|
||||
if sr != sample_rate:
|
||||
wav_data = librosa.resample(
|
||||
wav_data, orig_sr=sr, target_sr=self.sample_rate
|
||||
)
|
||||
wav_data = torch.from_numpy(wav_data)
|
||||
|
||||
return wav_data
|
||||
|
||||
score_lst = []
|
||||
for fname in tqdm(os.listdir(dir)):
|
||||
speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate)
|
||||
speech = speech.to(self.device)
|
||||
with torch.no_grad():
|
||||
score = self.score(speech)
|
||||
score_lst.append(score.item())
|
||||
return np.mean(score_lst)
|
||||
|
||||
|
||||
def load_ssl_model(ckpt_path="wav2vec_small.pt"):
|
||||
SSL_OUT_DIM = 768
|
||||
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||
[ckpt_path]
|
||||
)
|
||||
ssl_model = model[0]
|
||||
ssl_model.remove_pretraining_modules()
|
||||
return SSL_model(ssl_model, SSL_OUT_DIM)
|
||||
|
||||
|
||||
class BaselineLightningModule(pl.LightningModule):
|
||||
def __init__(self, ssl_model_path):
|
||||
super().__init__()
|
||||
self.construct_model(ssl_model_path)
|
||||
self.save_hyperparameters()
|
||||
|
||||
def construct_model(self, ssl_model_path):
|
||||
self.feature_extractors = nn.ModuleList(
|
||||
[
|
||||
load_ssl_model(ckpt_path=ssl_model_path),
|
||||
DomainEmbedding(3, 128),
|
||||
]
|
||||
)
|
||||
output_dim = sum(
|
||||
[
|
||||
feature_extractor.get_output_dim()
|
||||
for feature_extractor in self.feature_extractors
|
||||
]
|
||||
)
|
||||
output_layers = [
|
||||
LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
|
||||
]
|
||||
output_dim = output_layers[-1].get_output_dim()
|
||||
output_layers.append(
|
||||
Projection(
|
||||
hidden_dim=2048,
|
||||
activation=torch.nn.ReLU(),
|
||||
range_clipping=False,
|
||||
input_dim=output_dim,
|
||||
)
|
||||
)
|
||||
|
||||
self.output_layers = nn.ModuleList(output_layers)
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = {}
|
||||
for feature_extractor in self.feature_extractors:
|
||||
outputs.update(feature_extractor(inputs))
|
||||
x = outputs
|
||||
for output_layer in self.output_layers:
|
||||
x = output_layer(x, inputs)
|
||||
return x
|
||||
|
||||
|
||||
class SSL_model(nn.Module):
|
||||
def __init__(self, ssl_model, ssl_out_dim) -> None:
|
||||
super(SSL_model, self).__init__()
|
||||
self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
|
||||
|
||||
def forward(self, batch):
|
||||
wav = batch["wav"]
|
||||
wav = wav.squeeze(1) # [batches, wav_len]
|
||||
res = self.ssl_model(wav, mask=False, features_only=True)
|
||||
x = res["x"]
|
||||
return {"ssl-feature": x}
|
||||
|
||||
def get_output_dim(self):
|
||||
return self.ssl_out_dim
|
||||
|
||||
|
||||
class DomainEmbedding(nn.Module):
|
||||
def __init__(self, n_domains, domain_dim) -> None:
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(n_domains, domain_dim)
|
||||
self.output_dim = domain_dim
|
||||
|
||||
def forward(self, batch):
|
||||
return {"domain-feature": self.embedding(batch["domains"])}
|
||||
|
||||
def get_output_dim(self):
|
||||
return self.output_dim
|
||||
|
||||
|
||||
class LDConditioner(nn.Module):
|
||||
"""
|
||||
Conditions ssl output by listener embedding
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, judge_dim, num_judges=None):
|
||||
super().__init__()
|
||||
self.input_dim = input_dim
|
||||
self.judge_dim = judge_dim
|
||||
self.num_judges = num_judges
|
||||
assert num_judges is not None
|
||||
self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
|
||||
# concat [self.output_layer, phoneme features]
|
||||
|
||||
self.decoder_rnn = nn.LSTM(
|
||||
input_size=self.input_dim + self.judge_dim,
|
||||
hidden_size=512,
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
bidirectional=True,
|
||||
) # linear?
|
||||
self.out_dim = self.decoder_rnn.hidden_size * 2
|
||||
|
||||
def get_output_dim(self):
|
||||
return self.out_dim
|
||||
|
||||
def forward(self, x, batch):
|
||||
judge_ids = batch["judge_id"]
|
||||
if "phoneme-feature" in x.keys():
|
||||
concatenated_feature = torch.cat(
|
||||
(
|
||||
x["ssl-feature"],
|
||||
x["phoneme-feature"]
|
||||
.unsqueeze(1)
|
||||
.expand(-1, x["ssl-feature"].size(1), -1),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
else:
|
||||
concatenated_feature = x["ssl-feature"]
|
||||
if "domain-feature" in x.keys():
|
||||
concatenated_feature = torch.cat(
|
||||
(
|
||||
concatenated_feature,
|
||||
x["domain-feature"]
|
||||
.unsqueeze(1)
|
||||
.expand(-1, concatenated_feature.size(1), -1),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
if judge_ids is not None:
|
||||
concatenated_feature = torch.cat(
|
||||
(
|
||||
concatenated_feature,
|
||||
self.judge_embedding(judge_ids)
|
||||
.unsqueeze(1)
|
||||
.expand(-1, concatenated_feature.size(1), -1),
|
||||
),
|
||||
dim=2,
|
||||
)
|
||||
decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
|
||||
return decoder_output
|
||||
|
||||
|
||||
class Projection(nn.Module):
|
||||
def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
|
||||
super(Projection, self).__init__()
|
||||
self.range_clipping = range_clipping
|
||||
output_dim = 1
|
||||
if range_clipping:
|
||||
self.proj = nn.Tanh()
|
||||
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_dim),
|
||||
activation,
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(hidden_dim, output_dim),
|
||||
)
|
||||
self.output_dim = output_dim
|
||||
|
||||
def forward(self, x, batch):
|
||||
output = self.net(x)
|
||||
|
||||
# range clipping
|
||||
if self.range_clipping:
|
||||
return self.proj(output) * 2.0 + 3
|
||||
else:
|
||||
return output
|
||||
|
||||
def get_output_dim(self):
|
||||
return self.output_dim
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
UTMOS = UTMOSScore(
|
||||
utmos_model_path=args.utmos_model_path, ssl_model_path=args.ssl_model_path
|
||||
)
|
||||
score = UTMOS.score_dir(args.wav_path)
|
||||
logging.info(f"UTMOS score: {score:.2f}")
|
@@ -1,172 +0,0 @@
|
||||
"""
|
||||
Calculate WER with Hubert models.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from jiwer import compute_measures
|
||||
from tqdm import tqdm
|
||||
from transformers import pipeline
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
|
||||
parser.add_argument(
|
||||
"--decode-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the output file of WER information",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the local hubert model, e.g., model/huggingface/hubert-large-ls960-ft",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-list",
|
||||
type=str,
|
||||
default="test.tsv",
|
||||
help="path of the transcript tsv file, where the first column "
|
||||
"is the wav name and the last column is the transcript",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=16, help="decoding batch size"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def post_process(text: str):
|
||||
text = text.replace("‘", "'")
|
||||
text = text.replace("’", "'")
|
||||
text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def process_one(hypo, truth):
|
||||
truth = post_process(truth)
|
||||
hypo = post_process(hypo)
|
||||
|
||||
measures = compute_measures(truth, hypo)
|
||||
word_num = len(truth.split(" "))
|
||||
wer = measures["wer"]
|
||||
subs = measures["substitutions"]
|
||||
dele = measures["deletions"]
|
||||
inse = measures["insertions"]
|
||||
return (truth, hypo, wer, subs, dele, inse, word_num)
|
||||
|
||||
|
||||
class SpeechEvalDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, wav_path: str, test_list: str):
|
||||
super().__init__()
|
||||
self.wav_name = []
|
||||
self.wav_paths = []
|
||||
self.transcripts = []
|
||||
with Path(test_list).open("r", encoding="utf8") as f:
|
||||
meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
|
||||
for item in meta:
|
||||
self.wav_name.append(item[0])
|
||||
self.wav_paths.append(Path(wav_path, item[0] + ".wav"))
|
||||
self.transcripts.append(item[-1])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.wav_paths)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
wav, sampling_rate = sf.read(self.wav_paths[index])
|
||||
item = {
|
||||
"array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000),
|
||||
"sampling_rate": 16000,
|
||||
"reference": self.transcripts[index],
|
||||
"wav_name": self.wav_name[index],
|
||||
}
|
||||
return item
|
||||
|
||||
|
||||
def main(test_list, wav_path, model_path, decode_path, batch_size, device):
|
||||
|
||||
if model_path is not None:
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=model_path,
|
||||
device=device,
|
||||
tokenizer=model_path,
|
||||
)
|
||||
else:
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model="facebook/hubert-large-ls960-ft",
|
||||
device=device,
|
||||
)
|
||||
|
||||
dataset = SpeechEvalDataset(wav_path, test_list)
|
||||
|
||||
bar = tqdm(
|
||||
pipe(
|
||||
dataset,
|
||||
generate_kwargs={"language": "english", "task": "transcribe"},
|
||||
batch_size=batch_size,
|
||||
),
|
||||
total=len(dataset),
|
||||
)
|
||||
|
||||
wers = []
|
||||
inses = []
|
||||
deles = []
|
||||
subses = []
|
||||
word_nums = 0
|
||||
if decode_path:
|
||||
decode_dir = os.path.dirname(decode_path)
|
||||
if not os.path.exists(decode_dir):
|
||||
os.makedirs(decode_dir)
|
||||
fout = open(decode_path, "w")
|
||||
for out in bar:
|
||||
wav_name = out["wav_name"][0]
|
||||
transcription = post_process(out["text"].strip())
|
||||
text_ref = post_process(out["reference"][0].strip())
|
||||
truth, hypo, wer, subs, dele, inse, word_num = process_one(
|
||||
transcription, text_ref
|
||||
)
|
||||
if decode_path:
|
||||
fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
|
||||
wers.append(float(wer))
|
||||
inses.append(float(inse))
|
||||
deles.append(float(dele))
|
||||
subses.append(float(subs))
|
||||
word_nums += word_num
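# Corpus-level WER: total substitutions, deletions, and insertions divided by the total number of reference words.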
|
||||
|
||||
wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
|
||||
subs = round(np.mean(subses) * 100, 3)
|
||||
dele = round(np.mean(deles) * 100, 3)
|
||||
inse = round(np.mean(inses) * 100, 3)
|
||||
print(f"WER: {wer}%\n")
|
||||
if decode_path:
|
||||
fout.write(f"WER: {wer}%\n")
|
||||
fout.flush()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
main(
|
||||
args.test_list,
|
||||
args.wav_path,
|
||||
args.model_path,
|
||||
args.decode_path,
|
||||
args.batch_size,
|
||||
device,
|
||||
)
|
@@ -1,181 +0,0 @@
|
||||
"""
|
||||
Calculate WER with Whisper-large-v3 or Paraformer models,
|
||||
following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import string
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import zhconv
|
||||
from funasr import AutoModel
|
||||
from jiwer import compute_measures
|
||||
from tqdm import tqdm
|
||||
from transformers import WhisperForConditionalGeneration, WhisperProcessor
|
||||
from zhon.hanzi import punctuation
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--wav-path", type=str, help="path of the speech directory")
|
||||
parser.add_argument(
|
||||
"--decode-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the output file of WER information",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the local whisper and paraformer model, "
|
||||
"e.g., whisper: model/huggingface/whisper-large-v3/, "
|
||||
"paraformer: model/huggingface/paraformer-zh/",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-list",
|
||||
type=str,
|
||||
default="test.tsv",
|
||||
help="path of the transcript tsv file, where the first column "
|
||||
"is the wav name and the last column is the transcript",
|
||||
)
|
||||
parser.add_argument("--lang", type=str, help="decoded language, zh or en")
|
||||
return parser
|
||||
|
||||
|
||||
def load_en_model(model_path):
|
||||
if model_path is None:
|
||||
model_path = "openai/whisper-large-v3"
|
||||
processor = WhisperProcessor.from_pretrained(model_path)
|
||||
model = WhisperForConditionalGeneration.from_pretrained(model_path)
|
||||
return processor, model
|
||||
|
||||
|
||||
def load_zh_model(model_path):
|
||||
if model_path is None:
|
||||
model_path = "paraformer-zh"
|
||||
model = AutoModel(model=model_path)
|
||||
return model
|
||||
|
||||
|
||||
def process_one(hypo, truth, lang):
|
||||
punctuation_all = punctuation + string.punctuation
|
||||
for x in punctuation_all:
|
||||
if x == "'":
|
||||
continue
|
||||
truth = truth.replace(x, "")
|
||||
hypo = hypo.replace(x, "")
|
||||
|
||||
truth = truth.replace("  ", " ")
|
||||
hypo = hypo.replace("  ", " ")
|
||||
|
||||
if lang == "zh":
|
||||
truth = " ".join([x for x in truth])
|
||||
hypo = " ".join([x for x in hypo])
|
||||
elif lang == "en":
|
||||
truth = truth.lower()
|
||||
hypo = hypo.lower()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
measures = compute_measures(truth, hypo)
|
||||
word_num = len(truth.split(" "))
|
||||
wer = measures["wer"]
|
||||
subs = measures["substitutions"]
|
||||
dele = measures["deletions"]
|
||||
inse = measures["insertions"]
|
||||
return (truth, hypo, wer, subs, dele, inse, word_num)
|
||||
|
||||
|
||||
def main(test_list, wav_path, model_path, decode_path, lang, device):
|
||||
if lang == "en":
|
||||
processor, model = load_en_model(model_path)
|
||||
model.to(device)
|
||||
elif lang == "zh":
|
||||
model = load_zh_model(model_path)
|
||||
params = []
|
||||
for line in open(test_list).readlines():
|
||||
line = line.strip()
|
||||
items = line.split("\t")
|
||||
wav_name, text_ref = items[0], items[-1]
|
||||
file_path = os.path.join(wav_path, wav_name + ".wav")
|
||||
assert os.path.exists(file_path), f"{file_path}"
|
||||
|
||||
params.append((file_path, text_ref))
|
||||
wers = []
|
||||
inses = []
|
||||
deles = []
|
||||
subses = []
|
||||
word_nums = 0
|
||||
if decode_path:
|
||||
decode_dir = os.path.dirname(decode_path)
|
||||
if not os.path.exists(decode_dir):
|
||||
os.makedirs(decode_dir)
|
||||
fout = open(decode_path, "w")
|
||||
for wav_path, text_ref in tqdm(params):
|
||||
if lang == "en":
|
||||
wav, sr = sf.read(wav_path)
|
||||
if sr != 16000:
|
||||
wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
|
||||
input_features = processor(
|
||||
wav, sampling_rate=16000, return_tensors="pt"
|
||||
).input_features
|
||||
input_features = input_features.to(device)
|
||||
forced_decoder_ids = processor.get_decoder_prompt_ids(
|
||||
language="english", task="transcribe"
|
||||
)
|
||||
predicted_ids = model.generate(
|
||||
input_features, forced_decoder_ids=forced_decoder_ids
|
||||
)
|
||||
transcription = processor.batch_decode(
|
||||
predicted_ids, skip_special_tokens=True
|
||||
)[0]
|
||||
elif lang == "zh":
|
||||
res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
|
||||
transcription = res[0]["text"]
|
||||
transcription = zhconv.convert(transcription, "zh-cn")
|
||||
|
||||
truth, hypo, wer, subs, dele, inse, word_num = process_one(
|
||||
transcription, text_ref, lang
|
||||
)
|
||||
if decode_path:
|
||||
fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
|
||||
wers.append(float(wer))
|
||||
inses.append(float(inse))
|
||||
deles.append(float(dele))
|
||||
subses.append(float(subs))
|
||||
word_nums += word_num
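# Report both the Seed-TTS-style WER (mean of per-utterance WERs) and the
# corpus-level WER (total errors divided by total reference words).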
|
||||
|
||||
wer_avg = round(np.mean(wers) * 100, 3)
|
||||
wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
|
||||
subs = round(np.mean(subses) * 100, 3)
|
||||
dele = round(np.mean(deles) * 100, 3)
|
||||
inse = round(np.mean(inses) * 100, 3)
|
||||
print(f"Seed-TTS WER: {wer_avg}%\n")
|
||||
print(f"WER: {wer}%\n")
|
||||
if decode_path:
|
||||
fout.write(f"SeedTTS WER: {wer_avg}%\n")
|
||||
fout.write(f"WER: {wer}%\n")
|
||||
fout.flush()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
main(
|
||||
args.test_list,
|
||||
args.wav_path,
|
||||
args.model_path,
|
||||
args.decode_path,
|
||||
args.lang,
|
||||
device,
|
||||
)
|
@ -1 +0,0 @@
|
||||
../zipvoice/feature.py
|
File diff suppressed because it is too large
@ -1,90 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file generates the file that maps tokens to IDs.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from piper_phonemize import get_espeak_map
|
||||
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=Path,
|
||||
default=Path("data/tokens_emilia.txt"),
|
||||
help="Path to the dict that maps the text tokens to IDs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pinyin",
|
||||
type=Path,
|
||||
default=Path("local/pinyin.txt"),
|
||||
help="Path to the all unique pinyin",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_pinyin_tokens(pinyin: Path) -> List[str]:
|
||||
phones = set()
|
||||
with open(pinyin, "r") as f:
|
||||
for line in f:
|
||||
x = line.strip()
|
||||
initial = to_initials(x, strict=False)
|
||||
# don't want to share tokens with espeak tokens, so use tone3 style
|
||||
finals = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
|
||||
if initial != "":
|
||||
# don't want to share tokens with espeak tokens, so add a '0' after each initial
|
||||
phones.add(initial + "0")
|
||||
if finals != "":
|
||||
phones.add(finals)
|
||||
return sorted(phones)
|
||||
|
||||
|
||||
def get_token2id(args):
|
||||
"""Get a dict that maps token to IDs, and save it to the given filename."""
|
||||
all_tokens = get_espeak_map() # token: [token_id]
|
||||
all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
|
||||
# sort by token_id
|
||||
all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
|
||||
|
||||
all_pinyin = get_pinyin_tokens(args.pinyin)
|
||||
with open(args.tokens, "w", encoding="utf-8") as f:
|
||||
for token, token_id in all_tokens:
|
||||
f.write(f"{token} {token_id}\n")
|
||||
num_espeak_tokens = len(all_tokens)
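# Pinyin tokens are appended after the espeak tokens, so their IDs start at num_espeak_tokens.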
|
||||
for i, pinyin in enumerate(all_pinyin):
|
||||
f.write(f"{pinyin} {num_espeak_tokens + i}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_args()
|
||||
get_token2id(args)
|
@ -1,31 +0,0 @@
import re
from collections import Counter

from lhotse import load_manifest_lazy


def prepare_tokens(manifest_file, token_file):
    counter = Counter()
    manifest = load_manifest_lazy(manifest_file)
    for cut in manifest:
        line = re.sub(r"\s+", " ", cut.supervisions[0].text)
        counter.update(line)

    unique_chars = set(counter.keys())

    if "_" in unique_chars:
        unique_chars.remove("_")

    sorted_chars = sorted(unique_chars, key=lambda char: counter[char], reverse=True)

    result = ["_"] + sorted_chars

    with open(token_file, "w", encoding="utf-8") as file:
        for index, char in enumerate(result):
            file.write(f"{char} {index}\n")


if __name__ == "__main__":
    manifest_file = "data/fbank_libritts/libritts_cuts_train-all-shuf.jsonl.gz"
    output_token_file = "data/tokens_libritts.txt"
    prepare_tokens(manifest_file, output_token_file)
@ -1,155 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Zengwei Yao,
|
||||
# Zengrui Jin,
|
||||
# Wei Kang)
|
||||
# 2024 Tsinghua University (authors: Zengrui Jin,)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file reads the texts in given manifest and save the cleaned new cuts.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor as Pool
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from tokenizer import (
|
||||
is_alphabet,
|
||||
is_chinese,
|
||||
is_hangul,
|
||||
is_japanese,
|
||||
tokenize_by_CJK_char,
|
||||
)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--subset",
|
||||
type=str,
|
||||
help="Subset of emilia, (ZH, EN, etc.)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jobs",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of jobs to processing.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--source-dir",
|
||||
type=str,
|
||||
default="data/manifests/splits_raw",
|
||||
help="The source directory of manifest files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dest-dir",
|
||||
type=str,
|
||||
default="data/manifests/splits",
|
||||
help="The destination directory of manifest files.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def preprocess_emilia(file_name: str, input_dir: Path, output_dir: Path):
|
||||
logging.info(f"Processing {file_name}")
|
||||
if (output_dir / file_name).is_file():
|
||||
logging.info(f"{file_name} exists, skipping.")
|
||||
return
|
||||
|
||||
def _filter_cut(cut):
|
||||
text = cut.supervisions[0].text
|
||||
duration = cut.supervisions[0].duration
|
||||
chinese = []
|
||||
english = []
|
||||
|
||||
# only contains chinese and space and alphabets
|
||||
clean_chars = []
|
||||
for x in text:
|
||||
if is_hangul(x):
|
||||
logging.warning(f"Delete cut with text containing Korean : {text}")
|
||||
return False
|
||||
if is_japanese(x):
|
||||
logging.warning(f"Delete cut with text containing Japanese : {text}")
|
||||
return False
|
||||
if is_chinese(x):
|
||||
chinese.append(x)
|
||||
clean_chars.append(x)
|
||||
if is_alphabet(x):
|
||||
english.append(x)
|
||||
clean_chars.append(x)
|
||||
if x == " ":
|
||||
clean_chars.append(x)
|
||||
if len(english) + len(chinese) == 0:
|
||||
logging.warning(f"Delete cut with text has no valid chars : {text}")
|
||||
return False
|
||||
|
||||
words = tokenize_by_CJK_char("".join(clean_chars))
|
||||
for i in range(len(words) - 10):
|
||||
if words[i : i + 10].count(words[i]) == 10:
|
||||
logging.warning(f"Delete cut with text with too much repeats : {text}")
|
||||
return False
|
||||
# word speed, 20 - 600 / minute
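# i.e. the duration must lie between len(words)/600 and len(words)/20 minutes;
# otherwise the transcript is unlikely to match the audio.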
|
||||
if duration < len(words) / 600 * 60 or duration > len(words) / 20 * 60:
|
||||
logging.warning(
|
||||
f"Delete cut with audio text mismatch, duration : {duration}s, "
|
||||
f"words : {len(words)}, text : {text}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
try:
|
||||
cut_set = load_manifest_lazy(input_dir / file_name)
|
||||
cut_set = cut_set.filter(_filter_cut)
|
||||
cut_set.to_file(output_dir / file_name)
|
||||
except Exception as e:
|
||||
logging.error(f"Manifest {file_name} failed with error: {e}")
|
||||
os.remove(str(output_dir / file_name))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_args()
|
||||
|
||||
input_dir = Path(args.source_dir)
|
||||
output_dir = Path(args.dest_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
cut_files = glob.glob(f"{args.source_dir}/emilia_cuts_{args.subset}.*.jsonl.gz")
|
||||
|
||||
with Pool(max_workers=args.jobs) as pool:
|
||||
futures = [
|
||||
pool.submit(
|
||||
preprocess_emilia, filename.split("/")[-1], input_dir, output_dir
|
||||
)
|
||||
for filename in cut_files
|
||||
]
|
||||
for f in futures:
|
||||
f.result()
|
||||
f.done()
|
||||
logging.info("Processing done.")
|
@ -1 +0,0 @@
|
||||
../zipvoice/tokenizer.py
|
@ -1,70 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script checks the following assumptions of the generated manifest:
|
||||
|
||||
- Single supervision per cut
|
||||
|
||||
We will add more checks later if needed.
|
||||
|
||||
Usage example:
|
||||
|
||||
python3 ./local/validate_manifest.py \
|
||||
./data/spectrogram/ljspeech_cuts_all.jsonl.gz
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.dataset.speech_synthesis import validate_for_tts
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"manifest",
|
||||
type=Path,
|
||||
help="Path to the manifest file",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
manifest = args.manifest
|
||||
logging.info(f"Validating {manifest}")
|
||||
|
||||
assert manifest.is_file(), f"{manifest} does not exist"
|
||||
cut_set = load_manifest_lazy(manifest)
|
||||
assert isinstance(cut_set, CutSet), type(cut_set)
|
||||
|
||||
validate_for_tts(cut_set)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
@ -1,17 +0,0 @@
--find-links https://k2-fsa.github.io/icefall/piper_phonemize.html

torch
torchaudio
huggingface_hub
lhotse
safetensors
vocos

# Normalization
cn2an
inflect

# Tokenization
jieba
piper_phonemize
pypinyin
@ -1,102 +0,0 @@
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
export PYTHONWARNINGS=ignore
|
||||
export PYTHONPATH=../../:$PYTHONPATH
|
||||
|
||||
# Uncomment this if you have trouble connecting to HuggingFace
|
||||
# export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
start_stage=1
|
||||
end_stage=3
|
||||
|
||||
# Models used for SIM-o evaluation.
|
||||
# SV model wavlm_large_finetune.pth is downloaded from https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
|
||||
# SSL model wavlm_large.pt is downloaded from https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
|
||||
sv_model_path=model/UniSpeech/wavlm_large_finetune.pth
|
||||
wavlm_model_path=model/s3prl/wavlm_large.pt
|
||||
|
||||
# Models used for UTMOS evaluation.
|
||||
# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -O model/huggingface/utmos/utmos.pt
# wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -O model/huggingface/utmos/wav2vec_small.pt
|
||||
utmos_model_path=model/huggingface/utmos/utmos.pt
|
||||
wav2vec_model_path=model/huggingface/utmos/wav2vec_small.pt
|
||||
|
||||
|
||||
if [ $start_stage -le 1 ] && [ $end_stage -ge 1 ]; then
|
||||
|
||||
echo "=====Evaluate for Seed-TTS test-en======="
|
||||
test_list=testset/test_seedtts_en.tsv
|
||||
wav_path=results/zipvoice_seedtts_en
|
||||
|
||||
echo $wav_path
|
||||
echo "-----Computing SIM-o-----"
|
||||
python3 local/evaluate_sim.py \
|
||||
--sv-model-path ${sv_model_path} \
|
||||
--ssl-model-path ${wavlm_model_path} \
|
||||
--eval-path ${wav_path} \
|
||||
--test-list ${test_list}
|
||||
|
||||
echo "-----Computing WER-----"
|
||||
python3 local/evaluate_wer_seedtts.py \
|
||||
--test-list ${test_list} \
|
||||
--wav-path ${wav_path} \
|
||||
--lang "en"
|
||||
|
||||
echo "-----Computing UTSMOS-----"
|
||||
python3 local/evaluate_utmos.py \
|
||||
--wav-path ${wav_path} \
|
||||
--utmos-model-path ${utmos_model_path} \
|
||||
--ssl-model-path ${wav2vec_model_path}
|
||||
|
||||
fi
|
||||
|
||||
if [ $start_stage -le 2 ] && [ $end_stage -ge 2 ]; then
|
||||
echo "=====Evaluate for Seed-TTS test-zh======="
|
||||
test_list=testset/test_seedtts_zh.tsv
|
||||
wav_path=results/zipvoice_seedtts_zh
|
||||
|
||||
echo $wav_path
|
||||
echo "-----Computing SIM-o-----"
|
||||
python3 local/evaluate_sim.py \
|
||||
--sv-model-path ${sv_model_path} \
|
||||
--ssl-model-path ${wavlm_model_path} \
|
||||
--eval-path ${wav_path} \
|
||||
--test-list ${test_list}
|
||||
|
||||
echo "-----Computing WER-----"
|
||||
python3 local/evaluate_wer_seedtts.py \
|
||||
--test-list ${test_list} \
|
||||
--wav-path ${wav_path} \
|
||||
--lang "zh"
|
||||
|
||||
echo "-----Computing UTSMOS-----"
|
||||
python3 local/evaluate_utmos.py \
|
||||
--wav-path ${wav_path} \
|
||||
--utmos-model-path ${utmos_model_path} \
|
||||
--ssl-model-path ${wav2vec_model_path}
|
||||
fi
|
||||
|
||||
if [ $start_stage -le 3 ] && [ $end_stage -ge 3 ]; then
|
||||
echo "=====Evaluate for Librispeech test-clean======="
|
||||
test_list=testset/test_librispeech_pc_test_clean.tsv
|
||||
wav_path=results/zipvoice_librispeech_test_clean
|
||||
|
||||
echo $wav_path
|
||||
echo "-----Computing SIM-o-----"
|
||||
python3 local/evaluate_sim.py \
|
||||
--sv-model-path ${sv_model_path} \
|
||||
--ssl-model-path ${wavlm_model_path} \
|
||||
--eval-path ${wav_path} \
|
||||
--test-list ${test_list}
|
||||
|
||||
echo "-----Computing WER-----"
|
||||
python3 local/evaluate_wer_hubert.py \
|
||||
--test-list ${test_list} \
|
||||
--wav-path ${wav_path}
|
||||
|
||||
echo "-----Computing UTSMOS-----"
|
||||
python3 local/evaluate_utmos.py \
|
||||
--wav-path ${wav_path} \
|
||||
--utmos-model-path ${utmos_model_path} \
|
||||
--ssl-model-path ${wav2vec_model_path}
|
||||
|
||||
fi
|
@ -1,126 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=0
|
||||
stop_stage=5
|
||||
sampling_rate=24000
|
||||
nj=32
|
||||
|
||||
dl_dir=$PWD/download
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# Your download directory should look like this:
|
||||
#
|
||||
# download/Amphion___Emilia
|
||||
# ├── metafile.yaml
|
||||
# ├── raw
|
||||
# │ ├── DE
|
||||
# │ ├── EN
|
||||
# │ ├── FR
|
||||
# │ ├── JA
|
||||
# │ ├── KO
|
||||
# │ ├── openemilia_45batches.tar.gz
|
||||
# │ ├── openemilia_all.tar.gz
|
||||
# │ └── ZH
|
||||
# └── README.md
|
||||
|
||||
if [ ! -d $dl_dir/Amphion___Emilia/raw ]; then
|
||||
log "Please refer https://openxlab.org.cn/datasets/Amphion/Emilia to download the dataset."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare emilia manifests (EN and ZH only)"
|
||||
# We assume that you have downloaded the Emilia corpus
|
||||
# to $dl_dir/Amphion___Emilia
|
||||
# see stage 0 for the directory structure
|
||||
mkdir -p data/manifests
|
||||
if [ ! -e data/manifests/.emilia.done ]; then
|
||||
lhotse prepare emilia --lang en --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
|
||||
lhotse prepare emilia --lang zh --num-jobs ${nj} $dl_dir/Amphion___Emilia data/manifests
|
||||
touch data/manifests/.emilia.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Preprocess Emilia dataset, mainly for cleaning"
|
||||
mkdir -p data/manifests/splits_raw
|
||||
if [ ! -e data/manifests/splits_raw/.emilia.split.done ]; then
|
||||
lhotse split-lazy data/manifests/emilia_cuts_EN.jsonl.gz data/manifests/splits_raw 10000
|
||||
lhotse split-lazy data/manifests/emilia_cuts_ZH.jsonl.gz data/manifests/splits_raw 10000
|
||||
touch data/manifests/splits_raw/.emilia.split.done
|
||||
fi
|
||||
|
||||
mkdir -p data/manifests/splits
|
||||
|
||||
if [ ! -e data/manifests/splits/.emilia.preprocess.done ]; then
|
||||
python local/preprocess_emilia.py --subset EN
|
||||
python local/preprocess_emilia.py --subset ZH
|
||||
touch data/manifests/splits/.emilia.preprocess.done
|
||||
fi
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Extract Fbank for Emilia"
|
||||
mkdir -p data/fbank/emilia_splits
|
||||
if [ ! -e data/fbank/emilia_splits/.emilia.fbank.done ]; then
|
||||
# You can speed up the extraction by distributing splits to multiple machines.
|
||||
for subset in EN ZH; do
|
||||
python local/compute_fbank.py \
|
||||
--source-dir data/manifests/splits \
|
||||
--dest-dir data/fbank/emilia_splits \
|
||||
--dataset emilia \
|
||||
--subset ${subset} \
|
||||
--splits-cuts 1 \
|
||||
--split-begin 0 \
|
||||
--split-end 2000 \
|
||||
--num-jobs ${nj}
|
||||
done
|
||||
touch data/fbank/emilia_splits/.emilia.fbank.done
|
||||
fi
|
||||
|
||||
if [ ! -e data/fbank/emilia_cuts_EN.jsonl.gz ]; then
|
||||
log "Combining EN fbank cuts and spliting EN dev set"
|
||||
gunzip -c data/fbank/emilia_splits/emilia_cuts_EN.*.jsonl.gz > data/fbank/emilia_cuts_EN.jsonl
|
||||
head -n 1500 data/fbank/emilia_cuts_EN.jsonl | gzip -c > data/fbank/emilia_cuts_EN_dev.jsonl.gz
|
||||
sed -i '1,1500d' data/fbank/emilia_cuts_EN.jsonl
|
||||
gzip data/fbank/emilia_cuts_EN.jsonl
|
||||
fi
|
||||
|
||||
if [ ! -e data/fbank/emilia_cuts_ZH.jsonl.gz ]; then
|
||||
log "Combining ZH fbank cuts and spliting ZH dev set"
|
||||
gunzip -c data/fbank/emilia_splits/emilia_cuts_ZH.*.jsonl.gz > data/fbank/emilia_cuts_ZH.jsonl
|
||||
head -n 1500 data/fbank/emilia_cuts_ZH.jsonl | gzip -c > data/fbank/emilia_cuts_ZH_dev.jsonl.gz
|
||||
sed -i '1,1500d' data/fbank/emilia_cuts_ZH.jsonl
|
||||
gzip data/fbank/emilia_cuts_ZH.jsonl
|
||||
fi
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Generate token file"
|
||||
if [ ! -e data/tokens_emilia.txt ]; then
|
||||
./local/prepare_token_file_emilia.py --tokens data/tokens_emilia.txt
|
||||
fi
|
||||
fi
|
@ -1,97 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=0
|
||||
stop_stage=5
|
||||
sampling_rate=24000
|
||||
nj=20
|
||||
|
||||
dl_dir=$PWD/download
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/LibriTTS,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS
|
||||
#
|
||||
if [ ! -d $dl_dir/LibriTTS ]; then
|
||||
lhotse download libritts $dl_dir
|
||||
fi
|
||||
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare LibriTTS manifest"
|
||||
# We assume that you have downloaded the LibriTTS corpus
|
||||
# to $dl_dir/LibriTTS
|
||||
mkdir -p data/manifests
if [ ! -e data/manifests/.libritts.done ]; then
|
||||
lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests
|
||||
touch data/manifests/.libritts.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Compute Fbank for LibriTTS"
|
||||
mkdir -p data/fbank
|
||||
|
||||
if [ ! -e data/fbank/.libritts.done ]; then
|
||||
for subset in train-clean-100 train-clean-360 train-other-500 dev-clean test-clean; do
|
||||
python local/compute_fbank.py \
|
||||
--source-dir data/manifests \
|
||||
--dest-dir data/fbank \
|
||||
--dataset libritts \
|
||||
--subset ${subset} \
|
||||
--sampling-rate $sampling_rate \
|
||||
--num-jobs ${nj}
|
||||
done
|
||||
touch data/fbank/.libritts.done
|
||||
fi
|
||||
|
||||
# Here we shuffle and combine the train-clean-100, train-clean-360 and
|
||||
# train-other-500 together to form the training set.
|
||||
if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then
|
||||
cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
|
||||
<(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \
|
||||
<(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \
|
||||
shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
|
||||
fi
|
||||
|
||||
if [ ! -f data/fbank/libritts_cuts_train-clean-460.jsonl.gz ]; then
|
||||
cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \
|
||||
<(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) | \
|
||||
shuf | gzip -c > data/fbank/libritts_cuts_train-clean-460.jsonl.gz
|
||||
fi
|
||||
|
||||
if [ ! -e data/fbank/.libritts-validated.done ]; then
|
||||
log "Validating data/fbank for LibriTTS"
|
||||
./local/validate_manifest.py \
|
||||
data/fbank/libritts_cuts_train-all-shuf.jsonl.gz
|
||||
touch data/fbank/.libritts-validated.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Generate token file"
|
||||
if [ ! -e data/tokens_libritts.txt ]; then
|
||||
./local/prepare_token_file_libritts.py --tokens data/tokens_libritts.txt
|
||||
fi
|
||||
fi
|
@ -1,142 +0,0 @@
|
||||
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
|
||||
# use duck typing for LRScheduler since we have different possibilities, see
|
||||
# our class LRScheduler.
|
||||
LRSchedulerType = object
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
filename: Path,
|
||||
model: Union[nn.Module, DDP],
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
model_ema: Optional[nn.Module] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save training information to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
The checkpoint filename.
|
||||
model:
|
||||
The model to be saved. We only save its `state_dict()`.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
model_ema:
|
||||
The EMA version of model.
|
||||
params:
|
||||
User defined parameters, e.g., epoch, loss.
|
||||
optimizer:
|
||||
The optimizer to be saved. We only save its `state_dict()`.
|
||||
scheduler:
|
||||
The scheduler to be saved. We only save its `state_dict()`.
|
||||
scaler:
|
||||
The GradScaler to be saved. We only save its `state_dict()`.
|
||||
sampler:
|
||||
The sampler used in the labeled training dataset. We only
|
||||
save its `state_dict()`.
|
||||
rank:
|
||||
Used in DDP. We save checkpoint only for the node whose
|
||||
rank is 0.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
|
||||
logging.info(f"Saving checkpoint to {filename}")
|
||||
|
||||
if isinstance(model, DDP):
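# Unwrap DDP so the saved state_dict keys carry no "module." prefix.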
|
||||
model = model.module
|
||||
|
||||
checkpoint = {
|
||||
"model": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict() if optimizer is not None else None,
|
||||
"scheduler": scheduler.state_dict() if scheduler is not None else None,
|
||||
"grad_scaler": scaler.state_dict() if scaler is not None else None,
|
||||
"sampler": sampler.state_dict() if sampler is not None else None,
|
||||
}
|
||||
|
||||
if model_avg is not None:
|
||||
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
|
||||
if model_ema is not None:
|
||||
checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict()
|
||||
|
||||
if params:
|
||||
for k, v in params.items():
|
||||
assert k not in checkpoint
|
||||
checkpoint[k] = v
|
||||
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
filename: Path,
|
||||
model: Optional[nn.Module] = None,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
model_ema: Optional[nn.Module] = None,
|
||||
strict: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
logging.info(f"Loading checkpoint from {filename}")
|
||||
checkpoint = torch.load(filename, map_location="cpu")
|
||||
|
||||
if model is not None:
|
||||
|
||||
if next(iter(checkpoint["model"])).startswith("module."):
|
||||
logging.info("Loading checkpoint saved by DDP")
|
||||
|
||||
dst_state_dict = model.state_dict()
|
||||
src_state_dict = checkpoint["model"]
|
||||
for key in dst_state_dict.keys():
|
||||
src_key = "{}.{}".format("module", key)
|
||||
dst_state_dict[key] = src_state_dict.pop(src_key)
|
||||
assert len(src_state_dict) == 0
|
||||
model.load_state_dict(dst_state_dict, strict=strict)
|
||||
else:
|
||||
logging.info("Loading checkpoint")
|
||||
model.load_state_dict(checkpoint["model"], strict=strict)
|
||||
|
||||
checkpoint.pop("model")
|
||||
|
||||
if model_avg is not None and "model_avg" in checkpoint:
|
||||
logging.info("Loading averaged model")
|
||||
model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
|
||||
checkpoint.pop("model_avg")
|
||||
|
||||
if model_ema is not None and "model_ema" in checkpoint:
|
||||
logging.info("Loading ema model")
|
||||
model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
|
||||
checkpoint.pop("model_ema")
|
||||
|
||||
return checkpoint
|
@ -1,135 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from lhotse.features.base import FeatureExtractor, register_extractor
|
||||
from lhotse.utils import Seconds, compute_num_frames
|
||||
|
||||
|
||||
class MelSpectrogramFeatures(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sampling_rate=24000,
|
||||
n_mels=100,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.mel_spec = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=sampling_rate,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
n_mels=n_mels,
|
||||
center=True,
|
||||
power=1,
|
||||
)
|
||||
|
||||
def forward(self, inp):
|
||||
assert len(inp.shape) == 2
|
||||
|
||||
mel = self.mel_spec(inp)
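# Clamp to a small positive floor before the log to avoid log(0) on silent frames.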
|
||||
logmel = mel.clamp(min=1e-7).log()
|
||||
return logmel
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchAudioFbankConfig:
|
||||
sampling_rate: int
|
||||
n_mels: int
|
||||
n_fft: int
|
||||
hop_length: int
|
||||
|
||||
|
||||
@register_extractor
|
||||
class TorchAudioFbank(FeatureExtractor):
|
||||
|
||||
name = "TorchAudioFbank"
|
||||
config_type = TorchAudioFbankConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config=config)
|
||||
|
||||
def _feature_fn(self, sample):
|
||||
fbank = MelSpectrogramFeatures(
|
||||
sampling_rate=self.config.sampling_rate,
|
||||
n_mels=self.config.n_mels,
|
||||
n_fft=self.config.n_fft,
|
||||
hop_length=self.config.hop_length,
|
||||
)
|
||||
|
||||
return fbank(sample)
|
||||
|
||||
@property
|
||||
def device(self) -> Union[str, torch.device]:
|
||||
return self.config.device
|
||||
|
||||
def feature_dim(self, sampling_rate: int) -> int:
|
||||
return self.config.n_mels
|
||||
|
||||
def extract(
|
||||
self,
|
||||
samples: Union[np.ndarray, torch.Tensor],
|
||||
sampling_rate: int,
|
||||
) -> Union[np.ndarray, torch.Tensor]:
|
||||
# Check for sampling rate compatibility.
|
||||
expected_sr = self.config.sampling_rate
|
||||
assert sampling_rate == expected_sr, (
|
||||
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
||||
f"got {sampling_rate}"
|
||||
)
|
||||
is_numpy = False
|
||||
if not isinstance(samples, torch.Tensor):
|
||||
samples = torch.from_numpy(samples)
|
||||
is_numpy = True
|
||||
|
||||
if len(samples.shape) == 1:
|
||||
samples = samples.unsqueeze(0)
|
||||
assert samples.ndim == 2, samples.shape
|
||||
assert samples.shape[0] == 1, samples.shape
|
||||
|
||||
mel = self._feature_fn(samples).squeeze().t()
|
||||
|
||||
assert mel.ndim == 2, mel.shape
|
||||
assert mel.shape[1] == self.config.n_mels, mel.shape
|
||||
|
||||
num_frames = compute_num_frames(
|
||||
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
|
||||
)
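# Lhotse expects exactly num_frames frames for this duration; trim or pad the mel to match.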
|
||||
|
||||
if mel.shape[0] > num_frames:
|
||||
mel = mel[:num_frames]
|
||||
elif mel.shape[0] < num_frames:
|
||||
mel = mel.unsqueeze(0)
|
||||
mel = torch.nn.functional.pad(
|
||||
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
||||
).squeeze(0)
|
||||
|
||||
if is_numpy:
|
||||
return mel.cpu().numpy()
|
||||
else:
|
||||
return mel
|
||||
|
||||
@property
|
||||
def frame_shift(self) -> Seconds:
|
||||
return self.config.hop_length / self.config.sampling_rate
|
@ -1,209 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
This script loads checkpoints and averages them.
|
||||
|
||||
(1) Average ZipVoice models before distill:
|
||||
python3 ./zipvoice/generate_averaged_model.py \
|
||||
--epoch 11 \
|
||||
--avg 4 \
|
||||
--distill 0 \
|
||||
--token-file data/tokens_emilia.txt \
|
||||
--exp-dir ./zipvoice/exp_zipvoice
|
||||
|
||||
It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-11-avg-4.pt")`.
|
||||
|
||||
(2) Average ZipVoice-Distill models (the first stage model):
|
||||
|
||||
python3 ./zipvoice/generate_averaged_model.py \
|
||||
--iter 60000 \
|
||||
--avg 7 \
|
||||
--distill 1 \
|
||||
--token-file data/tokens_emilia.txt \
|
||||
--exp-dir ./zipvoice/exp_zipvoice_distill_1stage
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from model import get_distill_model, get_model
|
||||
from tokenizer import TokenizerEmilia, TokenizerLibriTTS
|
||||
from train_flow import add_model_arguments, get_params
|
||||
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=11,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' or --iter",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipvoice/exp_zipvoice",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--distill",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use distill model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="emilia",
|
||||
choices=["emilia", "libritts"],
|
||||
help="The used training dataset for the model to inference",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if params.dataset == "emilia":
|
||||
tokenizer = TokenizerEmilia(
|
||||
token_file=params.token_file, token_type=params.token_type
|
||||
)
|
||||
elif params.dataset == "libritts":
|
||||
tokenizer = TokenizerLibriTTS(
|
||||
token_file=params.token_file, token_type=params.token_type
|
||||
)
|
||||
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.pad_id = tokenizer.pad_id
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
print("Script started")
|
||||
|
||||
params.device = torch.device("cpu")
|
||||
print(f"Device: {params.device}")
|
||||
|
||||
print("About to create model")
|
||||
if params.distill:
|
||||
model = get_distill_model(params)
|
||||
else:
|
||||
model = get_model(params)
|
||||
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
print(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(params.device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=params.device,
|
||||
),
|
||||
strict=True,
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
print(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(params.device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=params.device,
|
||||
),
|
||||
strict=True,
|
||||
)
|
||||
if params.iter > 0:
|
||||
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
|
||||
else:
|
||||
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,586 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
|
||||
# Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script loads checkpoints to generate waveforms.
|
||||
This script is intended for models that you have trained yourself.
|
||||
If you want to use the pre-trained checkpoints provided by us, please refer to zipvoice_infer.py.
|
||||
|
||||
(1) Usage with a pre-trained checkpoint:
|
||||
|
||||
(a) ZipVoice model before distill:
|
||||
python3 zipvoice/infer.py \
|
||||
--checkpoint zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
|
||||
--distill 0 \
|
||||
--token-file "data/tokens_emilia.txt" \
|
||||
--test-list test.tsv \
|
||||
--res-dir results/test \
|
||||
--num-step 16 \
|
||||
--guidance-scale 1
|
||||
|
||||
(b) ZipVoice-Distill:
|
||||
python3 zipvoice/infer.py \
|
||||
--checkpoint zipvoice/exp_zipvoice_distill/checkpoint-2000.pt \
|
||||
--distill 1 \
|
||||
--token-file "data/tokens_emilia.txt" \
|
||||
--test-list test.tsv \
|
||||
--res-dir results/test_distill \
|
||||
--num-step 8 \
|
||||
--guidance-scale 3
|
||||
|
||||
(2) Usage with a directory of checkpoints (may requires checkpoint averaging):
|
||||
|
||||
(a) ZipVoice model before distill:
|
||||
python3 zipvoice/infer.py \
|
||||
--exp-dir zipvoice/exp_zipvoice \
|
||||
--epoch 11 \
|
||||
--avg 4 \
|
||||
--distill 0 \
|
||||
--token-file "data/tokens_emilia.txt" \
|
||||
--test-list test.tsv \
|
||||
--res-dir results \
|
||||
--num-step 16 \
|
||||
--guidance-scale 1
|
||||
|
||||
(b) ZipVoice-Distill:
|
||||
python3 zipvoice/infer.py \
|
||||
--exp-dir zipvoice/exp_zipvoice_distill/ \
|
||||
--iter 2000 \
|
||||
--avg 0 \
|
||||
--distill 1 \
|
||||
--token-file "data/tokens_emilia.txt" \
|
||||
--test-list test.tsv \
|
||||
--res-dir results \
|
||||
--num-step 8 \
|
||||
--guidance-scale 3
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from checkpoint import load_checkpoint
|
||||
from feature import TorchAudioFbank, TorchAudioFbankConfig
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import get_distill_model, get_model
|
||||
from tokenizer import TokenizerEmilia, TokenizerLibriTTS
|
||||
from train_flow import add_model_arguments, get_params
|
||||
from vocos import Vocos
|
||||
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The checkpoint for inference. "
|
||||
"If it is None, will use checkpoints under exp_dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipvoice/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' or '--iter', avg=0 means no avg",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocoder-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The local vocos vocoder path, downloaded from huggingface, "
|
||||
"will download the vocodoer from huggingface if it is None.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--distill",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether it is a distilled TTS model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test-list",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The list of prompt speech, prompt_transcription, "
|
||||
"and text to synthesize in the format of "
|
||||
"'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--res-dir",
|
||||
type=str,
|
||||
default="results",
|
||||
help="Path name of the generated wavs dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="emilia",
|
||||
choices=["emilia", "libritts"],
|
||||
help="The used training dataset for the model to inference",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--guidance-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The scale of classifier-free guidance during inference.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-step",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The number of sampling steps.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feat-scale",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="The scale factor of fbank feature",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--speed",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Control speech speed, 1.0 means normal, >1.0 means speed up",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--t-shift",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Shift t to smaller ones if t_shift < 1.0",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-rms",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Target speech normalization rms value",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=666,
|
||||
help="Random seed",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_vocoder(vocos_local_path: Optional[str] = None):
|
||||
if vocos_local_path:
# Load the Vocos vocoder from the given local directory.
vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
|
||||
state_dict = torch.load(
|
||||
f"{vocos_local_path}/pytorch_model.bin",
|
||||
weights_only=True,
|
||||
map_location="cpu",
|
||||
)
|
||||
vocoder.load_state_dict(state_dict)
|
||||
else:
|
||||
vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
||||
return vocoder
|
||||
|
||||
|
||||
def generate_sentence(
|
||||
save_path: str,
|
||||
prompt_text: str,
|
||||
prompt_wav: str,
|
||||
text: str,
|
||||
model: nn.Module,
|
||||
vocoder: nn.Module,
|
||||
tokenizer: TokenizerEmilia,
|
||||
feature_extractor: TorchAudioFbank,
|
||||
device: torch.device,
|
||||
num_step: int = 16,
|
||||
guidance_scale: float = 1.0,
|
||||
speed: float = 1.0,
|
||||
t_shift: float = 0.5,
|
||||
target_rms: float = 0.1,
|
||||
feat_scale: float = 0.1,
|
||||
sampling_rate: int = 24000,
|
||||
):
|
||||
"""
|
||||
Generate waveform of a text based on a given prompt
|
||||
waveform and its transcription.
|
||||
|
||||
Args:
|
||||
save_path (str): Path to save the generated wav.
|
||||
prompt_text (str): Transcription of the prompt wav.
|
||||
prompt_wav (str): Path to the prompt wav file.
|
||||
text (str): Text to be synthesized into a waveform.
|
||||
model (nn.Module): The model used for generation.
|
||||
vocoder (nn.Module): The vocoder used to convert features to waveforms.
|
||||
tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
|
||||
feature_extractor (TorchAudioFbank): The feature extractor used to
|
||||
extract acoustic features.
|
||||
device (torch.device): The device on which computations are performed.
|
||||
num_step (int, optional): Number of steps for decoding. Defaults to 16.
|
||||
guidance_scale (float, optional): Scale for classifier-free guidance.
|
||||
Defaults to 1.0.
|
||||
speed (float, optional): Speed control. Defaults to 1.0.
|
||||
t_shift (float, optional): Time shift. Defaults to 0.5.
|
||||
target_rms (float, optional): Target RMS for waveform normalization.
|
||||
Defaults to 0.1.
|
||||
feat_scale (float, optional): Scale for features.
|
||||
Defaults to 0.1.
|
||||
sampling_rate (int, optional): Sampling rate for the waveform.
|
||||
Defaults to 24000.
|
||||
Returns:
|
||||
metrics (dict): Dictionary containing time and real-time
|
||||
factor metrics for processing.
|
||||
"""
|
||||
# Convert text to tokens
|
||||
tokens = tokenizer.texts_to_token_ids([text])
|
||||
prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
|
||||
|
||||
# Load and preprocess prompt wav
|
||||
prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
|
||||
prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
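# Normalize quiet prompts up to target_rms; the inverse scaling is applied to the generated wav below.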
|
||||
if prompt_rms < target_rms:
|
||||
prompt_wav = prompt_wav * target_rms / prompt_rms
|
||||
|
||||
if prompt_sampling_rate != sampling_rate:
|
||||
resampler = torchaudio.transforms.Resample(
|
||||
orig_freq=prompt_sampling_rate, new_freq=sampling_rate
|
||||
)
|
||||
prompt_wav = resampler(prompt_wav)
|
||||
|
||||
# Extract features from prompt wav
|
||||
prompt_features = feature_extractor.extract(
|
||||
prompt_wav, sampling_rate=sampling_rate
|
||||
).to(device)
|
||||
prompt_features = prompt_features.unsqueeze(0) * feat_scale
|
||||
prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
|
||||
|
||||
# Start timing
|
||||
start_t = dt.datetime.now()
|
||||
|
||||
# Generate features
|
||||
(
|
||||
pred_features,
|
||||
pred_features_lens,
|
||||
pred_prompt_features,
|
||||
pred_prompt_features_lens,
|
||||
) = model.sample(
|
||||
tokens=tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_features=prompt_features,
|
||||
prompt_features_lens=prompt_features_lens,
|
||||
speed=speed,
|
||||
t_shift=t_shift,
|
||||
duration="predict",
|
||||
num_step=num_step,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
||||
|
||||
# Postprocess predicted features
|
||||
pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
|
||||
|
||||
# Start vocoder processing
|
||||
start_vocoder_t = dt.datetime.now()
|
||||
wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
|
||||
|
||||
# Calculate processing times and real-time factors
|
||||
t = (dt.datetime.now() - start_t).total_seconds()
|
||||
t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
|
||||
t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
|
||||
wav_seconds = wav.shape[-1] / sampling_rate
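# Real-time factor (RTF) = processing time / duration of the generated audio; lower is faster.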
|
||||
rtf = t / wav_seconds
|
||||
rtf_no_vocoder = t_no_vocoder / wav_seconds
|
||||
rtf_vocoder = t_vocoder / wav_seconds
|
||||
metrics = {
|
||||
"t": t,
|
||||
"t_no_vocoder": t_no_vocoder,
|
||||
"t_vocoder": t_vocoder,
|
||||
"wav_seconds": wav_seconds,
|
||||
"rtf": rtf,
|
||||
"rtf_no_vocoder": rtf_no_vocoder,
|
||||
"rtf_vocoder": rtf_vocoder,
|
||||
}
|
||||
|
||||
# Adjust wav volume if necessary
|
||||
if prompt_rms < target_rms:
|
||||
wav = wav * prompt_rms / target_rms
|
||||
wav = wav[0].cpu().numpy()
|
||||
sf.write(save_path, wav, sampling_rate)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def generate(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
vocoder: nn.Module,
|
||||
tokenizer: TokenizerEmilia,
|
||||
):
|
||||
total_t = []
|
||||
total_t_no_vocoder = []
|
||||
total_t_vocoder = []
|
||||
total_wav_seconds = []
|
||||
|
||||
config = TorchAudioFbankConfig(
|
||||
sampling_rate=params.sampling_rate,
|
||||
n_mels=100,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
)
|
||||
feature_extractor = TorchAudioFbank(config)
|
||||
|
||||
with open(params.test_list, "r") as fr:
|
||||
lines = fr.readlines()
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
|
||||
save_path = f"{params.wav_dir}/{wav_name}.wav"
|
||||
metrics = generate_sentence(
|
||||
save_path=save_path,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav=prompt_wav,
|
||||
text=text,
|
||||
model=model,
|
||||
vocoder=vocoder,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
device=params.device,
|
||||
num_step=params.num_step,
|
||||
guidance_scale=params.guidance_scale,
|
||||
speed=params.speed,
|
||||
t_shift=params.t_shift,
|
||||
target_rms=params.target_rms,
|
||||
feat_scale=params.feat_scale,
|
||||
sampling_rate=params.sampling_rate,
|
||||
)
|
||||
print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
|
||||
total_t.append(metrics["t"])
|
||||
total_t_no_vocoder.append(metrics["t_no_vocoder"])
|
||||
total_t_vocoder.append(metrics["t_vocoder"])
|
||||
total_wav_seconds.append(metrics["wav_seconds"])
|
||||
|
||||
print(f"Average RTF: " f"{np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
|
||||
print(
|
||||
f"Average RTF w/o vocoder: "
|
||||
f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
|
||||
)
|
||||
print(
|
||||
f"Average RTF vocoder: "
|
||||
f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = (
|
||||
f"wavs-iter-{params.iter}-avg"
|
||||
f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
|
||||
)
|
||||
elif params.epoch > 0:
|
||||
params.suffix = (
|
||||
f"wavs-epoch-{params.epoch}-avg"
|
||||
f"-{params.avg}-step-{params.num_step}-scale-{params.guidance_scale}"
|
||||
)
|
||||
else:
|
||||
params.suffix = "wavs"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
params.device = torch.device("cuda", 0)
|
||||
else:
|
||||
params.device = torch.device("cpu")
|
||||
|
||||
logging.info(f"Device: {params.device}")
|
||||
|
||||
if params.dataset == "emilia":
|
||||
tokenizer = TokenizerEmilia(
|
||||
token_file=params.token_file, token_type=params.token_type
|
||||
)
|
||||
elif params.dataset == "libritts":
|
||||
tokenizer = TokenizerLibriTTS(
|
||||
token_file=params.token_file, token_type=params.token_type
|
||||
)
|
||||
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.pad_id = tokenizer.pad_id
|
||||
|
||||
logging.info(params)
|
||||
fix_random_seed(params.seed)
|
||||
|
||||
logging.info("About to create model")
|
||||
if params.distill:
|
||||
model = get_distill_model(params)
|
||||
else:
|
||||
model = get_model(params)
|
||||
|
||||
if params.checkpoint:
|
||||
load_checkpoint(params.checkpoint, model, strict=True)
|
||||
else:
|
||||
if params.avg == 0:
|
||||
if params.iter > 0:
|
||||
load_checkpoint(
|
||||
f"{params.exp_dir}/checkpoint-{params.iter}.pt", model, strict=True
|
||||
)
|
||||
else:
|
||||
load_checkpoint(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True
|
||||
)
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(params.device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=params.device,
|
||||
),
|
||||
strict=True,
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(params.device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=params.device,
|
||||
),
|
||||
strict=True,
|
||||
)
|
||||
|
||||
model = model.to(params.device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
vocoder = get_vocoder(params.vocoder_path)
|
||||
vocoder = vocoder.to(params.device)
|
||||
vocoder.eval()
|
||||
num_param = sum([p.numel() for p in vocoder.parameters()])
|
||||
logging.info(f"Number of vocoder parameters: {num_param}")
|
||||
|
||||
params.wav_dir = f"{params.res_dir}/{params.suffix}"
|
||||
os.makedirs(params.wav_dir, exist_ok=True)
|
||||
|
||||
assert (
|
||||
params.test_list is not None
|
||||
), "Please provide --test-list for speech synthesize."
|
||||
generate(
|
||||
params=params,
|
||||
model=model,
|
||||
vocoder=vocoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
@@ -1,578 +0,0 @@
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang
|
||||
# Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ScheduledFloat
|
||||
from solver import EulerSolver
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from utils import (
|
||||
AttributeDict,
|
||||
condition_time_mask,
|
||||
get_tokens_index,
|
||||
make_pad_mask,
|
||||
pad_labels,
|
||||
prepare_avg_tokens_durations,
|
||||
to_int_tuple,
|
||||
)
|
||||
from zipformer import TTSZipformer
|
||||
|
||||
|
||||
def get_model(params: AttributeDict) -> nn.Module:
|
||||
"""Get the normal TTS model."""
|
||||
|
||||
fm_decoder = get_fm_decoder_model(params)
|
||||
text_encoder = get_text_encoder_model(params)
|
||||
|
||||
model = TtsModel(
|
||||
fm_decoder=fm_decoder,
|
||||
text_encoder=text_encoder,
|
||||
text_embed_dim=params.text_embed_dim,
|
||||
feat_dim=params.feat_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
pad_id=params.pad_id,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def get_distill_model(params: AttributeDict) -> nn.Module:
|
||||
"""Get the distillation TTS model."""
|
||||
|
||||
fm_decoder = get_fm_decoder_model(params, distill=True)
|
||||
text_encoder = get_text_encoder_model(params)
|
||||
|
||||
model = DistillTTSModelTrainWrapper(
|
||||
fm_decoder=fm_decoder,
|
||||
text_encoder=text_encoder,
|
||||
text_embed_dim=params.text_embed_dim,
|
||||
feat_dim=params.feat_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
pad_id=params.pad_id,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def get_fm_decoder_model(params: AttributeDict, distill: bool = False) -> nn.Module:
|
||||
"""Get the Zipformer-based FM decoder model."""
|
||||
|
||||
encoder = TTSZipformer(
|
||||
in_dim=params.feat_dim * 3,
|
||||
out_dim=params.feat_dim,
|
||||
downsampling_factor=to_int_tuple(params.fm_decoder_downsampling_factor),
|
||||
num_encoder_layers=to_int_tuple(params.fm_decoder_num_layers),
|
||||
cnn_module_kernel=to_int_tuple(params.fm_decoder_cnn_module_kernel),
|
||||
encoder_dim=params.fm_decoder_dim,
|
||||
feedforward_dim=params.fm_decoder_feedforward_dim,
|
||||
num_heads=params.fm_decoder_num_heads,
|
||||
query_head_dim=params.query_head_dim,
|
||||
pos_head_dim=params.pos_head_dim,
|
||||
value_head_dim=params.value_head_dim,
|
||||
pos_dim=params.pos_dim,
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
warmup_batches=4000.0,
|
||||
use_time_embed=True,
|
||||
time_embed_dim=192,
|
||||
use_guidance_scale_embed=distill,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
||||
def get_text_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
"""Get the Zipformer-based text encoder model."""
|
||||
|
||||
encoder = TTSZipformer(
|
||||
in_dim=params.text_embed_dim,
|
||||
out_dim=params.feat_dim,
|
||||
downsampling_factor=to_int_tuple(params.text_encoder_downsampling_factor),
|
||||
num_encoder_layers=to_int_tuple(params.text_encoder_num_layers),
|
||||
cnn_module_kernel=to_int_tuple(params.text_encoder_cnn_module_kernel),
|
||||
encoder_dim=params.text_encoder_dim,
|
||||
feedforward_dim=params.text_encoder_feedforward_dim,
|
||||
num_heads=params.text_encoder_num_heads,
|
||||
query_head_dim=params.query_head_dim,
|
||||
pos_head_dim=params.pos_head_dim,
|
||||
value_head_dim=params.value_head_dim,
|
||||
pos_dim=params.pos_dim,
|
||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||
warmup_batches=4000.0,
|
||||
use_time_embed=False,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
||||
class TtsModel(nn.Module):
|
||||
"""The normal TTS model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fm_decoder: nn.Module,
|
||||
text_encoder: nn.Module,
|
||||
text_embed_dim: int,
|
||||
feat_dim: int,
|
||||
vocab_size: int,
|
||||
pad_id: int = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fm_decoder: the flow-matching decoder model (a Zipformer encoder); its
inputs are the condition embeddings and noisy acoustic features,
and its outputs are refined acoustic features.
text_encoder: the text encoder model; its inputs are text
embeddings and its outputs are contextualized text embeddings.
|
||||
text_embed_dim: dimension of text embedding.
|
||||
feat_dim: dimension of acoustic features.
|
||||
vocab_size: vocabulary size.
|
||||
pad_id: padding id.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.feat_dim = feat_dim
|
||||
self.text_embed_dim = text_embed_dim
|
||||
self.pad_id = pad_id
|
||||
|
||||
self.fm_decoder = fm_decoder
|
||||
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
self.embed = nn.Embedding(vocab_size, text_embed_dim)
|
||||
|
||||
self.distill = False
|
||||
|
||||
def forward_fm_decoder(
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
xt: torch.Tensor,
|
||||
text_condition: torch.Tensor,
|
||||
speech_condition: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
guidance_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Compute velocity.
|
||||
Args:
|
||||
t: A tensor of shape (N, 1, 1) or a tensor of a float,
|
||||
in the range of (0, 1).
|
||||
xt: the input of the current timestep, including condition
|
||||
embeddings and noisy acoustic features.
|
||||
text_condition: the text condition embeddings, with the
|
||||
shape (batch, seq_len, emb_dim).
|
||||
speech_condition: the speech condition embeddings, with the
|
||||
shape (batch, seq_len, emb_dim).
|
||||
padding_mask: The mask for padding, True means masked
|
||||
position, with the shape (N, T).
|
||||
guidance_scale: The guidance scale in classifier-free guidance,
|
||||
which is a tensor of shape (N, 1, 1) or a tensor of a float.
|
||||
|
||||
Returns:
|
||||
predicted velocity, with the shape (batch, seq_len, emb_dim).
|
||||
"""
|
||||
assert t.dim() in (0, 3)
|
||||
# Handle t with the shape (N, 1, 1):
|
||||
# squeeze the last dimension if its size is 1.
|
||||
while t.dim() > 1 and t.size(-1) == 1:
|
||||
t = t.squeeze(-1)
|
||||
if guidance_scale is not None:
|
||||
while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1:
|
||||
guidance_scale = guidance_scale.squeeze(-1)
|
||||
# Handle t with a single value: expand to the size of batch size.
|
||||
if t.dim() == 0:
|
||||
t = t.repeat(xt.shape[0])
|
||||
if guidance_scale is not None and guidance_scale.dim() == 0:
|
||||
guidance_scale = guidance_scale.repeat(xt.shape[0])
|
||||
|
||||
xt = torch.cat([xt, text_condition, speech_condition], dim=2)
|
||||
vt = self.fm_decoder(
|
||||
x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale
|
||||
)
|
||||
return vt
|
||||
|
||||
def forward_text_embed(
|
||||
self,
|
||||
tokens: List[List[int]],
|
||||
):
|
||||
"""
|
||||
Get the text embeddings.
|
||||
Args:
|
||||
tokens: a list of list of token ids.
|
||||
Returns:
|
||||
embed: the text embeddings, shape (batch, seq_len, emb_dim).
|
||||
tokens_lens: the length of each token sequence, shape (batch,).
|
||||
"""
|
||||
device = (
|
||||
self.device if isinstance(self, DDP) else next(self.parameters()).device
|
||||
)
|
||||
tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
|
||||
embed = self.embed(tokens_padded) # (B, S, C)
|
||||
tokens_lens = torch.tensor(
|
||||
[len(token) for token in tokens], dtype=torch.int64, device=device
|
||||
)
|
||||
tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
|
||||
|
||||
embed = self.text_encoder(
|
||||
x=embed, t=None, padding_mask=tokens_padding_mask
|
||||
) # (B, S, C)
|
||||
return embed, tokens_lens
|
||||
|
||||
def forward_text_condition(
|
||||
self,
|
||||
embed: torch.Tensor,
|
||||
tokens_lens: torch.Tensor,
|
||||
features_lens: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Get the text condition with the same length of the acoustic feature.
|
||||
Args:
|
||||
embed: the text embeddings, shape (batch, token_seq_len, emb_dim).
|
||||
tokens_lens: the length of each token sequence, shape (batch,).
|
||||
features_lens: the length of each acoustic feature sequence,
|
||||
shape (batch,).
|
||||
Returns:
|
||||
text_condition: the text condition, shape
|
||||
(batch, feature_seq_len, emb_dim).
|
||||
padding_mask: the padding mask of text condition, shape
|
||||
(batch, feature_seq_len).
|
||||
"""
|
||||
|
||||
num_frames = int(features_lens.max())
|
||||
|
||||
padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
|
||||
|
||||
tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens)
|
||||
|
||||
tokens_index = get_tokens_index(tokens_durations, num_frames).to(
|
||||
embed.device
|
||||
) # (B, T)
|
||||
|
||||
text_condition = torch.gather(
|
||||
embed,
|
||||
dim=1,
|
||||
index=tokens_index.unsqueeze(-1).expand(
|
||||
embed.size(0), num_frames, embed.size(-1)
|
||||
),
|
||||
) # (B, T, F)
|
||||
return text_condition, padding_mask
|
||||
|
||||
def forward_text_train(
|
||||
self,
|
||||
tokens: List[List[int]],
|
||||
features_lens: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Process text for training, given text tokens and real feature lengths.
|
||||
"""
|
||||
embed, tokens_lens = self.forward_text_embed(tokens)
|
||||
text_condition, padding_mask = self.forward_text_condition(
|
||||
embed, tokens_lens, features_lens
|
||||
)
|
||||
return (
|
||||
text_condition,
|
||||
padding_mask,
|
||||
)
|
||||
|
||||
def forward_text_inference_gt_duration(
|
||||
self,
|
||||
tokens: List[List[int]],
|
||||
features_lens: torch.Tensor,
|
||||
prompt_tokens: List[List[int]],
|
||||
prompt_features_lens: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Process text for inference, given text tokens, real feature lengths and prompts.
|
||||
"""
|
||||
tokens = [
|
||||
prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
|
||||
]
|
||||
features_lens = prompt_features_lens + features_lens
|
||||
embed, tokens_lens = self.forward_text_embed(tokens)
|
||||
text_condition, padding_mask = self.forward_text_condition(
|
||||
embed, tokens_lens, features_lens
|
||||
)
|
||||
return text_condition, padding_mask
|
||||
|
||||
def forward_text_inference_ratio_duration(
|
||||
self,
|
||||
tokens: List[List[int]],
|
||||
prompt_tokens: List[List[int]],
|
||||
prompt_features_lens: torch.Tensor,
|
||||
speed: float,
|
||||
):
|
||||
"""
|
||||
Process text for inference, given text tokens and prompts;
feature lengths are predicted from the ratio of token counts.
|
||||
"""
|
||||
device = (
|
||||
self.device if isinstance(self, DDP) else next(self.parameters()).device
|
||||
)
|
||||
|
||||
cat_tokens = [
|
||||
prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
|
||||
]
|
||||
|
||||
prompt_tokens_lens = torch.tensor(
|
||||
[len(token) for token in prompt_tokens], dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens)
|
||||
|
||||
features_lens = torch.ceil(
|
||||
(prompt_features_lens / prompt_tokens_lens * cat_tokens_lens / speed)
|
||||
).to(dtype=torch.int64)
|
||||
|
||||
text_condition, padding_mask = self.forward_text_condition(
|
||||
cat_embed, cat_tokens_lens, features_lens
|
||||
)
|
||||
return text_condition, padding_mask
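
For intuition, the ratio-based duration rule above scales the prompt's frames-per-token ratio to the concatenated token sequence. Here is a minimal sketch with hypothetical numbers (not taken from the recipe):

```python
import math

# Hypothetical example of the duration prediction in
# forward_text_inference_ratio_duration: the prompt has 200 feature frames
# and 20 tokens; the concatenated (prompt + target) sequence has 50 tokens.
prompt_frames = 200
prompt_tokens = 20
cat_tokens = 50
speed = 1.0

# frames-per-token of the prompt, scaled to the concatenated token count
predicted_frames = math.ceil(prompt_frames / prompt_tokens * cat_tokens / speed)
print(predicted_frames)  # 500
```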
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokens: List[List[int]],
|
||||
features: torch.Tensor,
|
||||
features_lens: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
condition_drop_ratio: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass of the model for training.
|
||||
Args:
|
||||
tokens: a list of list of token ids.
|
||||
features: the acoustic features, with the shape (batch, seq_len, feat_dim).
|
||||
features_lens: the length of each acoustic feature sequence, shape (batch,).
|
||||
noise: the initial noise, with the shape (batch, seq_len, feat_dim).
|
||||
t: the time step, with the shape (batch, 1, 1).
|
||||
condition_drop_ratio: the ratio of dropped text condition.
|
||||
Returns:
|
||||
fm_loss: the flow-matching loss.
|
||||
"""
|
||||
|
||||
(text_condition, padding_mask,) = self.forward_text_train(
|
||||
tokens=tokens,
|
||||
features_lens=features_lens,
|
||||
)
|
||||
|
||||
speech_condition_mask = condition_time_mask(
|
||||
features_lens=features_lens,
|
||||
mask_percent=(0.7, 1.0),
|
||||
max_len=features.size(1),
|
||||
)
|
||||
speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
|
||||
|
||||
if condition_drop_ratio > 0.0:
|
||||
drop_mask = (
|
||||
torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
|
||||
> condition_drop_ratio
|
||||
)
|
||||
text_condition = text_condition * drop_mask
|
||||
|
||||
xt = features * t + noise * (1 - t)
|
||||
ut = features - noise # (B, T, F)
|
||||
|
||||
vt = self.forward_fm_decoder(
|
||||
t=t,
|
||||
xt=xt,
|
||||
text_condition=text_condition,
|
||||
speech_condition=speech_condition,
|
||||
padding_mask=padding_mask,
|
||||
)
|
||||
|
||||
loss_mask = speech_condition_mask & (~padding_mask)
|
||||
fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
|
||||
|
||||
return fm_loss
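
Written out, the loss computed above is the standard conditional flow-matching objective; this is only a restatement of the code, with x_1 the target features, x_0 the noise, and c the (possibly dropped) text/speech condition:

```latex
x_t = t\,x_1 + (1 - t)\,x_0, \qquad u_t = x_1 - x_0,
\qquad
\mathcal{L}_{\mathrm{FM}} = \mathbb{E}\left[\lVert v_\theta(x_t, t, c) - u_t \rVert^2\right]
```

with the expectation taken over the masked, non-padding frames selected by loss_mask.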
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokens: List[List[int]],
|
||||
prompt_tokens: List[List[int]],
|
||||
prompt_features: torch.Tensor,
|
||||
prompt_features_lens: torch.Tensor,
|
||||
features_lens: Optional[torch.Tensor] = None,
|
||||
speed: float = 1.0,
|
||||
t_shift: float = 1.0,
|
||||
duration: str = "predict",
|
||||
num_step: int = 5,
|
||||
guidance_scale: float = 0.5,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generate acoustic features, given text tokens, prompt features,
and the prompt transcription's text tokens.
|
||||
Args:
|
||||
tokens: a list of list of text tokens.
|
||||
prompt_tokens: a list of list of prompt tokens.
|
||||
prompt_features: the prompt feature with the shape
|
||||
(batch_size, seq_len, feat_dim).
|
||||
prompt_features_lens: the length of each prompt feature,
|
||||
with the shape (batch_size,).
|
||||
features_lens: the length of the predicted feature, with the
|
||||
shape (batch_size,). It is used only when duration is "real".
|
||||
duration: "real" or "predict". If "real", the predicted
|
||||
feature length is given by features_lens.
|
||||
num_step: the number of steps to use in the ODE solver.
|
||||
guidance_scale: the guidance scale for classifier-free guidance.
|
||||
Note: whether the distillation model is used for sampling is determined by self.distill.
|
||||
"""
|
||||
|
||||
assert duration in ["real", "predict"]
|
||||
|
||||
if duration == "predict":
|
||||
(
|
||||
text_condition,
|
||||
padding_mask,
|
||||
) = self.forward_text_inference_ratio_duration(
|
||||
tokens=tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_features_lens=prompt_features_lens,
|
||||
speed=speed,
|
||||
)
|
||||
else:
|
||||
assert features_lens is not None
|
||||
text_condition, padding_mask = self.forward_text_inference_gt_duration(
|
||||
tokens=tokens,
|
||||
features_lens=features_lens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_features_lens=prompt_features_lens,
|
||||
)
|
||||
batch_size, num_frames, _ = text_condition.shape
|
||||
|
||||
speech_condition = torch.nn.functional.pad(
|
||||
prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
|
||||
) # (B, T, F)
|
||||
|
||||
# False means speech condition positions.
|
||||
speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
|
||||
speech_condition = torch.where(
|
||||
speech_condition_mask.unsqueeze(-1),
|
||||
torch.zeros_like(speech_condition),
|
||||
speech_condition,
|
||||
)
|
||||
|
||||
x0 = torch.randn(
|
||||
batch_size, num_frames, self.feat_dim, device=text_condition.device
|
||||
)
|
||||
solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
|
||||
|
||||
x1 = solver.sample(
|
||||
x=x0,
|
||||
text_condition=text_condition,
|
||||
speech_condition=speech_condition,
|
||||
padding_mask=padding_mask,
|
||||
num_step=num_step,
|
||||
guidance_scale=guidance_scale,
|
||||
t_shift=t_shift,
|
||||
)
|
||||
x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens
|
||||
x1_prompt = torch.zeros(
|
||||
x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device
|
||||
)
|
||||
x1_wo_prompt = torch.zeros(
|
||||
x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device
|
||||
)
|
||||
for i in range(x1.size(0)):
|
||||
x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[
|
||||
i,
|
||||
prompt_features_lens[i] : prompt_features_lens[i]
|
||||
+ x1_wo_prompt_lens[i],
|
||||
]
|
||||
x1_prompt[i, : prompt_features_lens[i], :] = x1[
|
||||
i, : prompt_features_lens[i]
|
||||
]
|
||||
|
||||
return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens
|
||||
|
||||
def sample_intermediate(
|
||||
self,
|
||||
tokens: List[List[int]],
|
||||
features: torch.Tensor,
|
||||
features_lens: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
speech_condition_mask: torch.Tensor,
|
||||
t_start: torch.Tensor,
|
||||
t_end: torch.Tensor,
|
||||
num_step: int = 1,
|
||||
guidance_scale: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generate acoustic features in intermediate timesteps.
|
||||
Args:
|
||||
tokens: List of list of token ids.
|
||||
features: The acoustic features, with the shape (batch, seq_len, feat_dim).
|
||||
features_lens: The length of each acoustic feature sequence,
|
||||
with the shape (batch,).
|
||||
noise: The initial noise, with the shape (batch, seq_len, feat_dim).
|
||||
speech_condition_mask: The mask for speech condition, True means
|
||||
non-condition positions, with the shape (batch, seq_len).
|
||||
t_start: The start timestep, with the shape (batch, 1, 1).
|
||||
t_end: The end timestep, with the shape (batch, 1, 1).
|
||||
num_step: The number of steps for sampling.
|
||||
guidance_scale: The scale for classifier-free guidance inference,
|
||||
with the shape (batch, 1, 1).
|
||||
Note: whether the distillation model is used is determined by self.distill.
|
||||
"""
|
||||
(text_condition, padding_mask,) = self.forward_text_train(
|
||||
tokens=tokens,
|
||||
features_lens=features_lens,
|
||||
)
|
||||
|
||||
speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
|
||||
|
||||
solver = EulerSolver(self, distill=self.distill, func_name="forward_fm_decoder")
|
||||
|
||||
x_t_end = solver.sample(
|
||||
x=noise,
|
||||
text_condition=text_condition,
|
||||
speech_condition=speech_condition,
|
||||
padding_mask=padding_mask,
|
||||
num_step=num_step,
|
||||
guidance_scale=guidance_scale,
|
||||
t_start=t_start,
|
||||
t_end=t_end,
|
||||
)
|
||||
x_t_end_lens = (~padding_mask).sum(-1)
|
||||
return x_t_end, x_t_end_lens
|
||||
|
||||
|
||||
class DistillTTSModelTrainWrapper(TtsModel):
|
||||
"""Wrapper for training the distilled TTS model."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.distill = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tokens: List[List[int]],
|
||||
features: torch.Tensor,
|
||||
features_lens: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
speech_condition_mask: torch.Tensor,
|
||||
t_start: torch.Tensor,
|
||||
t_end: torch.Tensor,
|
||||
num_step: int = 1,
|
||||
guidance_scale: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
return self.sample_intermediate(
|
||||
tokens=tokens,
|
||||
features=features,
|
||||
features_lens=features_lens,
|
||||
noise=noise,
|
||||
speech_condition_mask=speech_condition_mask,
|
||||
t_start=t_start,
|
||||
t_end=t_end,
|
||||
num_step=num_step,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,277 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class DiffusionModel(torch.nn.Module):
|
||||
"""A wrapper of diffusion models for inference.
|
||||
Args:
|
||||
model: The diffusion model.
|
||||
distill: Whether it is a distillation model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
distill: bool = False,
|
||||
func_name: str = "forward_fm_decoder",
|
||||
):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.distill = distill
|
||||
self.func_name = func_name
|
||||
self.model_func = getattr(self.model, func_name)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
text_condition: torch.Tensor,
|
||||
speech_condition: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
guidance_scale: Union[float, torch.Tensor] = 0.0,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward function that handles classifier-free guidance.
|
||||
Args:
|
||||
t: The current timestep, a tensor of shape (batch, 1, 1) or a tensor of a single float.
|
||||
x: The initial value, with the shape (batch, seq_len, emb_dim).
|
||||
text_condition: The text condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
speech_condition: The speech condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
guidance_scale: The scale of classifier-free guidance, a float or a tensor of shape (batch, 1, 1).
Returns:
The prediction with the shape (batch, seq_len, emb_dim).
|
||||
"""
|
||||
if not torch.is_tensor(guidance_scale):
|
||||
guidance_scale = torch.tensor(
|
||||
guidance_scale, dtype=t.dtype, device=t.device
|
||||
)
|
||||
if self.distill:
|
||||
return self.model_func(
|
||||
t=t,
|
||||
xt=x,
|
||||
text_condition=text_condition,
|
||||
speech_condition=speech_condition,
|
||||
padding_mask=padding_mask,
|
||||
guidance_scale=guidance_scale,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if (guidance_scale == 0.0).all():
|
||||
return self.model_func(
|
||||
t=t,
|
||||
xt=x,
|
||||
text_condition=text_condition,
|
||||
speech_condition=speech_condition,
|
||||
padding_mask=padding_mask,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
if t.dim() != 0:
|
||||
t = torch.cat([t] * 2, dim=0)
|
||||
|
||||
x = torch.cat([x] * 2, dim=0)
|
||||
padding_mask = torch.cat([padding_mask] * 2, dim=0)
|
||||
|
||||
text_condition = torch.cat(
|
||||
[torch.zeros_like(text_condition), text_condition], dim=0
|
||||
)
|
||||
|
||||
if t.dim() == 0:
|
||||
if t > 0.5:
|
||||
speech_condition = torch.cat(
|
||||
[torch.zeros_like(speech_condition), speech_condition], dim=0
|
||||
)
|
||||
else:
|
||||
guidance_scale = guidance_scale * 2
|
||||
speech_condition = torch.cat(
|
||||
[speech_condition, speech_condition], dim=0
|
||||
)
|
||||
else:
|
||||
assert t.dim() > 0, t
|
||||
larger_t_index = (t > 0.5).squeeze(1).squeeze(1)
|
||||
zero_speech_condition = torch.cat(
|
||||
[torch.zeros_like(speech_condition), speech_condition], dim=0
|
||||
)
|
||||
speech_condition = torch.cat(
|
||||
[speech_condition, speech_condition], dim=0
|
||||
)
|
||||
speech_condition[larger_t_index] = zero_speech_condition[larger_t_index]
|
||||
guidance_scale[~larger_t_index[: larger_t_index.size(0) // 2]] *= 2
|
||||
|
||||
data_uncond, data_cond = self.model_func(
|
||||
t=t,
|
||||
xt=x,
|
||||
text_condition=text_condition,
|
||||
speech_condition=speech_condition,
|
||||
padding_mask=padding_mask,
|
||||
**kwargs
|
||||
).chunk(2, dim=0)
|
||||
|
||||
res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
|
||||
return res
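
The combination returned above is the usual classifier-free guidance rule; with guidance scale w it reads:

```latex
v_{\mathrm{cfg}} = (1 + w)\, v_\theta(x_t, t, c) - w\, v_\theta(x_t, t, \tilde{c})
```

where the unconditional branch with condition \tilde{c} always zeroes the text condition and, matching the branches above, also zeroes the speech condition when t > 0.5 (otherwise the speech condition is kept and w is doubled).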
|
||||
|
||||
|
||||
class EulerSolver:
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
distill: bool = False,
|
||||
func_name: str = "forward_fm_decoder",
|
||||
):
|
||||
"""Construct a Euler Solver
|
||||
Args:
|
||||
model: The diffusion model.
|
||||
distill: Whether it is a distillation model.
|
||||
"""
|
||||
|
||||
self.model = DiffusionModel(model, distill=distill, func_name=func_name)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
text_condition: torch.Tensor,
|
||||
speech_condition: torch.Tensor,
|
||||
padding_mask: torch.Tensor,
|
||||
num_step: int = 10,
|
||||
guidance_scale: Union[float, torch.Tensor] = 0.0,
|
||||
t_start: Union[float, torch.Tensor] = 0.0,
|
||||
t_end: Union[float, torch.Tensor] = 1.0,
|
||||
t_shift: float = 1.0,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the sample at time `t_end` by Euler Solver.
|
||||
Args:
|
||||
x: The initial value at time `t_start`, with the shape (batch, seq_len, emb_dim).
|
||||
text_condition: The text condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
speech_condition: The speech condition of the diffusion model, with the shape (batch, seq_len, emb_dim).
|
||||
padding_mask: The mask for padding; True means masked position, with the shape (batch, seq_len).
|
||||
num_step: The number of ODE steps.
|
||||
guidance_scale: The scale for classifier-free guidance, which is
|
||||
a float or a tensor with the shape (batch, 1, 1).
|
||||
t_start: the start timestep in the range of [0, 1],
|
||||
which is a float or a tensor with the shape (batch, 1, 1).
|
||||
t_end: the end time_step in the range of [0, 1],
|
||||
which is a float or a tensor with the shape (batch, 1, 1).
|
||||
t_shift: shift the t toward smaller numbers so that the sampling
|
||||
will emphasize low SNR region. Should be in the range of (0, 1].
|
||||
The shifting will be more significant when the number is smaller.
|
||||
|
||||
Returns:
|
||||
The approximated solution at time `t_end`.
|
||||
"""
|
||||
device = x.device
|
||||
|
||||
if torch.is_tensor(t_start) and t_start.dim() > 0:
|
||||
timesteps = get_time_steps_batch(
|
||||
t_start=t_start,
|
||||
t_end=t_end,
|
||||
num_step=num_step,
|
||||
t_shift=t_shift,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
timesteps = get_time_steps(
|
||||
t_start=t_start,
|
||||
t_end=t_end,
|
||||
num_step=num_step,
|
||||
t_shift=t_shift,
|
||||
device=device,
|
||||
)
|
||||
for step in range(num_step):
|
||||
v = self.model(
|
||||
t=timesteps[step],
|
||||
x=x,
|
||||
text_condition=text_condition,
|
||||
speech_condition=speech_condition,
|
||||
padding_mask=padding_mask,
|
||||
guidance_scale=guidance_scale,
|
||||
**kwargs
|
||||
)
|
||||
x = x + v * (timesteps[step + 1] - timesteps[step])
|
||||
return x
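
Each iteration of the loop above is a plain forward-Euler step of the flow ODE, restated here for clarity:

```latex
x_{k+1} = x_k + v_\theta(t_k, x_k, c)\,(t_{k+1} - t_k), \qquad k = 0, \dots, \text{num\_step} - 1
```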
|
||||
|
||||
|
||||
def get_time_steps(
|
||||
t_start: float = 0.0,
|
||||
t_end: float = 1.0,
|
||||
num_step: int = 10,
|
||||
t_shift: float = 1.0,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Compute the intermediate time steps for sampling.
|
||||
|
||||
Args:
|
||||
t_start: The starting time of the sampling (default is 0).
|
||||
t_end: The ending time of the sampling (default is 1).
num_step: The number of sampling steps.
|
||||
t_shift: shift the t toward smaller numbers so that the sampling
|
||||
will emphasize low SNR region. Should be in the range of (0, 1].
|
||||
The shifting will be more significant when the number is smaller.
|
||||
device: A torch device.
|
||||
Returns:
|
||||
The time steps, with the shape (num_step + 1,).
|
||||
"""
|
||||
|
||||
timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
|
||||
|
||||
timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
|
||||
|
||||
return timesteps
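
To see what t_shift does, here is a small self-contained sketch (the settings are hypothetical and the printed values approximate): the warp t' = s*t / (1 + (s - 1)*t) with s < 1 pushes a uniform grid toward 0, i.e. toward the low-SNR end of the trajectory.

```python
import torch

t = torch.linspace(0.0, 1.0, steps=6)   # uniform grid for num_step = 5
s = 0.5                                  # t_shift
warped = s * t / (1 + (s - 1) * t)       # same warp as in get_time_steps

print([round(v, 3) for v in t.tolist()])       # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
print([round(v, 3) for v in warped.tolist()])  # ~[0.0, 0.111, 0.25, 0.429, 0.667, 1.0]
```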
|
||||
|
||||
|
||||
def get_time_steps_batch(
|
||||
t_start: torch.Tensor,
|
||||
t_end: torch.Tensor,
|
||||
num_step: int = 10,
|
||||
t_shift: float = 1.0,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Compute the intermediate time steps for sampling in the batch mode.
|
||||
|
||||
Args:
|
||||
t_start: The starting time of the sampling (default is 0), with the shape (batch, 1, 1).
|
||||
t_end: The ending time of the sampling (default is 1), with the shape (batch, 1, 1).
num_step: The number of sampling steps.
|
||||
t_shift: shift the t toward smaller numbers so that the sampling
|
||||
will emphasize low SNR region. Should be in the range of (0, 1].
|
||||
The shifting will be more significant when the number is smaller.
|
||||
device: A torch device.
|
||||
Returns:
|
||||
The time steps, with the shape (num_step + 1, N, 1, 1).
|
||||
"""
|
||||
while t_start.dim() > 1 and t_start.size(-1) == 1:
|
||||
t_start = t_start.squeeze(-1)
|
||||
while t_end.dim() > 1 and t_end.size(-1) == 1:
|
||||
t_end = t_end.squeeze(-1)
|
||||
assert t_start.dim() == t_end.dim() == 1
|
||||
|
||||
timesteps_shape = (num_step + 1, t_start.size(0))
|
||||
timesteps = torch.zeros(timesteps_shape, device=device)
|
||||
|
||||
for i in range(t_start.size(0)):
|
||||
timesteps[:, i] = torch.linspace(t_start[i], t_end[i], steps=num_step + 1)
|
||||
|
||||
timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
|
||||
|
||||
return timesteps.unsqueeze(-1).unsqueeze(-1)
|
@@ -1,572 +0,0 @@
|
||||
# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao
|
||||
# Han Zhu,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
import unicodedata
|
||||
from functools import reduce
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import cn2an
|
||||
import inflect
|
||||
import jieba
|
||||
from pypinyin import Style, lazy_pinyin
|
||||
from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
|
||||
|
||||
try:
|
||||
from piper_phonemize import phonemize_espeak
|
||||
except Exception as ex:
|
||||
raise RuntimeError(
|
||||
f"{ex}\nPlease run\n"
|
||||
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
|
||||
)
|
||||
|
||||
_inflect = inflect.engine()
|
||||
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
||||
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
||||
_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
|
||||
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
||||
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
||||
_fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
|
||||
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
||||
_number_re = re.compile(r"[0-9]+")
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [
|
||||
(re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mrs", "misess"),
|
||||
("mr", "mister"),
|
||||
("dr", "doctor"),
|
||||
("st", "saint"),
|
||||
("co", "company"),
|
||||
("jr", "junior"),
|
||||
("maj", "major"),
|
||||
("gen", "general"),
|
||||
("drs", "doctors"),
|
||||
("rev", "reverend"),
|
||||
("lt", "lieutenant"),
|
||||
("hon", "honorable"),
|
||||
("sgt", "sergeant"),
|
||||
("capt", "captain"),
|
||||
("esq", "esquire"),
|
||||
("ltd", "limited"),
|
||||
("col", "colonel"),
|
||||
("ft", "fort"),
|
||||
("etc", "et cetera"),
|
||||
("btw", "by the way"),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def intersperse(sequence, item=0):
|
||||
result = [item] * (len(sequence) * 2 + 1)
|
||||
result[1::2] = sequence
|
||||
return result
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(",", "")
|
||||
|
||||
|
||||
def _expand_decimal_point(m):
|
||||
return m.group(1).replace(".", " point ")
|
||||
|
||||
|
||||
def _expand_percent(m):
|
||||
return m.group(1).replace("%", " percent ")
|
||||
|
||||
|
||||
def _expand_dollars(m):
|
||||
match = m.group(1)
|
||||
parts = match.split(".")
|
||||
if len(parts) > 2:
|
||||
return " " + match + " dollars " # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||
cent_unit = "cent" if cents == 1 else "cents"
|
||||
return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||
return " %s %s " % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = "cent" if cents == 1 else "cents"
|
||||
return " %s %s " % (cents, cent_unit)
|
||||
else:
|
||||
return " zero dollars "
|
||||
|
||||
|
||||
def fraction_to_words(numerator, denominator):
|
||||
if numerator == 1 and denominator == 2:
|
||||
return " one half "
|
||||
if numerator == 1 and denominator == 4:
|
||||
return " one quarter "
|
||||
if denominator == 2:
|
||||
return " " + _inflect.number_to_words(numerator) + " halves "
|
||||
if denominator == 4:
|
||||
return " " + _inflect.number_to_words(numerator) + " quarters "
|
||||
return (
|
||||
" "
|
||||
+ _inflect.number_to_words(numerator)
|
||||
+ " "
|
||||
+ _inflect.ordinal(_inflect.number_to_words(denominator))
|
||||
+ " "
|
||||
)
|
||||
|
||||
|
||||
def _expand_fraction(m):
|
||||
numerator = int(m.group(1))
|
||||
denominator = int(m.group(2))
|
||||
return fraction_to_words(numerator, denominator)
|
||||
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return " " + _inflect.number_to_words(m.group(0)) + " "
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return " two thousand "
|
||||
elif num > 2000 and num < 2010:
|
||||
return " two thousand " + _inflect.number_to_words(num % 100) + " "
|
||||
elif num % 100 == 0:
|
||||
return " " + _inflect.number_to_words(num // 100) + " hundred "
|
||||
else:
|
||||
return (
|
||||
" "
|
||||
+ _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
|
||||
", ", " "
|
||||
)
|
||||
+ " "
|
||||
)
|
||||
else:
|
||||
return " " + _inflect.number_to_words(num, andword="") + " "
|
||||
|
||||
|
||||
# Normalize numbers pronunciation
|
||||
def normalize_numbers(text):
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r"\1 pounds", text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
text = re.sub(_fraction_re, _expand_fraction, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_percent_number_re, _expand_percent, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
|
||||
|
||||
|
||||
# Convert numbers to Chinese pronunciation
|
||||
def number_to_chinese(text):
|
||||
text = cn2an.transform(text, "an2cn")
|
||||
return text
|
||||
|
||||
|
||||
def map_punctuations(text):
|
||||
text = text.replace(",", ",")
|
||||
text = text.replace("。", ".")
|
||||
text = text.replace("!", "!")
|
||||
text = text.replace("?", "?")
|
||||
text = text.replace(";", ";")
|
||||
text = text.replace(":", ":")
|
||||
text = text.replace("、", ",")
|
||||
text = text.replace("‘", "'")
|
||||
text = text.replace("“", '"')
|
||||
text = text.replace("”", '"')
|
||||
text = text.replace("’", "'")
|
||||
text = text.replace("⋯", "…")
|
||||
text = text.replace("···", "…")
|
||||
text = text.replace("・・・", "…")
|
||||
text = text.replace("...", "…")
|
||||
return text
|
||||
|
||||
|
||||
def is_chinese(char):
|
||||
if char >= "\u4e00" and char <= "\u9fa5":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_alphabet(char):
|
||||
if (char >= "\u0041" and char <= "\u005a") or (
|
||||
char >= "\u0061" and char <= "\u007a"
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_hangul(char):
|
||||
letters = unicodedata.normalize("NFD", char)
|
||||
return all(
|
||||
["\u1100" <= c <= "\u11ff" or "\u3131" <= c <= "\u318e" for c in letters]
|
||||
)
|
||||
|
||||
|
||||
def is_japanese(char):
|
||||
return any(
|
||||
[
|
||||
start <= char <= end
|
||||
for start, end in [
|
||||
("\u3041", "\u3096"),
|
||||
("\u30a0", "\u30ff"),
|
||||
("\uff5f", "\uff9f"),
|
||||
("\u31f0", "\u31ff"),
|
||||
("\u3220", "\u3243"),
|
||||
("\u3280", "\u337f"),
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_segment(text: str) -> List[str]:
|
||||
# sentence --> [ch_part, en_part, ch_part, ...]
|
||||
# example :
|
||||
# input : 我们是小米人,是吗? Yes I think so!霍...啦啦啦
|
||||
# output : [('我们是小米人,是吗? ', 'zh'), ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')]
|
||||
segments = []
|
||||
types = []
|
||||
flag = 0
|
||||
temp_seg = ""
|
||||
temp_lang = ""
|
||||
|
||||
for i, ch in enumerate(text):
|
||||
if is_chinese(ch):
|
||||
types.append("zh")
|
||||
elif is_alphabet(ch):
|
||||
types.append("en")
|
||||
else:
|
||||
types.append("other")
|
||||
|
||||
assert len(types) == len(text)
|
||||
|
||||
for i in range(len(types)):
|
||||
# find the first char of the seg
|
||||
if flag == 0:
|
||||
temp_seg += text[i]
|
||||
temp_lang = types[i]
|
||||
flag = 1
|
||||
else:
|
||||
if temp_lang == "other":
|
||||
if types[i] == temp_lang:
|
||||
temp_seg += text[i]
|
||||
else:
|
||||
temp_seg += text[i]
|
||||
temp_lang = types[i]
|
||||
else:
|
||||
if types[i] == temp_lang:
|
||||
temp_seg += text[i]
|
||||
elif types[i] == "other":
|
||||
temp_seg += text[i]
|
||||
else:
|
||||
segments.append((temp_seg, temp_lang))
|
||||
temp_seg = text[i]
|
||||
temp_lang = types[i]
|
||||
flag = 1
|
||||
|
||||
segments.append((temp_seg, temp_lang))
|
||||
return segments
|
||||
|
||||
|
||||
def preprocess(text: str) -> str:
|
||||
text = map_punctuations(text)
|
||||
return text
|
||||
|
||||
|
||||
def tokenize_ZH(text: str) -> List[str]:
|
||||
try:
|
||||
text = number_to_chinese(text)
|
||||
segs = list(jieba.cut(text))
|
||||
full = lazy_pinyin(
|
||||
segs, style=Style.TONE3, tone_sandhi=True, neutral_tone_with_five=True
|
||||
)
|
||||
phones = []
|
||||
for x in full:
|
||||
# valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
|
||||
if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")):
|
||||
phones.append(x)
|
||||
continue
|
||||
initial = to_initials(x, strict=False)
|
||||
# don't want to share tokens with espeak tokens, so use tone3 style
|
||||
final = to_finals_tone3(x, strict=False, neutral_tone_with_five=True)
|
||||
if initial != "":
|
||||
# don't want to share tokens with espeak tokens, so add a '0' after each initial
|
||||
phones.append(initial + "0")
|
||||
if final != "":
|
||||
phones.append(final)
|
||||
return phones
|
||||
except Exception as ex:
|
||||
logging.warning(f"Tokenize ZH failed: {ex}")
|
||||
return []
|
||||
|
||||
|
||||
def tokenize_EN(text: str) -> List[str]:
|
||||
try:
|
||||
text = expand_abbreviations(text)
|
||||
text = normalize_numbers(text)
|
||||
tokens = phonemize_espeak(text, "en-us")
|
||||
tokens = reduce(lambda x, y: x + y, tokens)
|
||||
return tokens
|
||||
except Exception as ex:
|
||||
logging.warning(f"Tokenize EN failed: {ex}")
|
||||
return []
|
||||
|
||||
|
||||
class TokenizerEmilia(object):
|
||||
def __init__(self, token_file: Optional[str] = None, token_type="phone"):
|
||||
"""
|
||||
Args:
|
||||
token_file: the file that contains information that maps tokens to ids,
|
||||
which is a text file with '{token} {token_id}' per line.
|
||||
"""
|
||||
assert (
|
||||
token_type == "phone"
|
||||
), f"Only support phone tokenizer for Emilia, but get {token_type}."
|
||||
self.has_tokens = False
|
||||
if token_file is None:
|
||||
logging.debug(
|
||||
"Initialize Tokenizer without tokens file, will fail when map to ids."
|
||||
)
|
||||
return
|
||||
self.token2id: Dict[str, int] = {}
|
||||
with open(token_file, "r", encoding="utf-8") as f:
|
||||
for line in f.readlines():
|
||||
info = line.rstrip().split("\t")
|
||||
token, id = info[0], int(info[1])
|
||||
assert token not in self.token2id, token
|
||||
self.token2id[token] = id
|
||||
self.pad_id = self.token2id["_"] # padding
|
||||
|
||||
self.vocab_size = len(self.token2id)
|
||||
self.has_tokens = True
|
||||
|
||||
def texts_to_token_ids(
|
||||
self,
|
||||
texts: List[str],
|
||||
) -> List[List[int]]:
|
||||
return self.tokens_to_token_ids(self.texts_to_tokens(texts))
|
||||
|
||||
def texts_to_tokens(
|
||||
self,
|
||||
texts: List[str],
|
||||
) -> List[List[str]]:
|
||||
"""
|
||||
Args:
|
||||
texts:
|
||||
A list of transcripts.
|
||||
Returns:
|
||||
Return a list of a list of tokens [utterance][token]
|
||||
"""
|
||||
for i in range(len(texts)):
|
||||
# Text normalization
|
||||
texts[i] = preprocess(texts[i])
|
||||
|
||||
phoneme_list = []
|
||||
for text in texts:
|
||||
# now only en and ch
|
||||
segments = get_segment(text)
|
||||
all_phoneme = []
|
||||
for index in range(len(segments)):
|
||||
seg = segments[index]
|
||||
if seg[1] == "zh":
|
||||
phoneme = tokenize_ZH(seg[0])
|
||||
else:
|
||||
if seg[1] != "en":
|
||||
logging.warning(
|
||||
f"The lang should be en, given {seg[1]}, skipping segment : {seg}"
|
||||
)
|
||||
continue
|
||||
phoneme = tokenize_EN(seg[0])
|
||||
all_phoneme += phoneme
|
||||
phoneme_list.append(all_phoneme)
|
||||
return phoneme_list
|
||||
|
||||
def tokens_to_token_ids(
|
||||
self,
|
||||
tokens: List[List[str]],
|
||||
intersperse_blank: bool = False,
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Args:
|
||||
tokens_list:
|
||||
A list of token list, each corresponding to one utterance.
|
||||
intersperse_blank:
|
||||
Whether to intersperse blanks in the token sequence.
|
||||
|
||||
Returns:
|
||||
Return a list of token id list [utterance][token_id]
|
||||
"""
|
||||
assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
|
||||
token_ids = []
|
||||
|
||||
for tks in tokens:
|
||||
ids = []
|
||||
for t in tks:
|
||||
if t not in self.token2id:
|
||||
logging.warning(f"Skip OOV {t}")
|
||||
continue
|
||||
ids.append(self.token2id[t])
|
||||
|
||||
if intersperse_blank:
|
||||
ids = intersperse(ids, self.pad_id)
|
||||
|
||||
token_ids.append(ids)
|
||||
|
||||
return token_ids
|
||||
|
||||
|
||||
class TokenizerLibriTTS(object):
|
||||
def __init__(self, token_file: str, token_type: str):
|
||||
"""
|
||||
Args:
|
||||
token_type: the type of tokenizer, e.g., bpe, char, phone.
token_file: the file that contains information that maps tokens to ids,
|
||||
which is a text file with '{token} {token_id}' per line if type is
|
||||
char or phone, otherwise it is a bpe_model file.
|
||||
"""
|
||||
self.type = token_type
|
||||
assert token_type in ["bpe", "char", "phone"]
|
||||
# Parse token file
|
||||
|
||||
if token_type == "bpe":
|
||||
import sentencepiece as spm
|
||||
|
||||
self.sp = spm.SentencePieceProcessor()
|
||||
self.sp.load(token_file)
|
||||
self.pad_id = self.sp.piece_to_id("<pad>")
|
||||
self.vocab_size = self.sp.get_piece_size()
|
||||
else:
|
||||
self.token2id: Dict[str, int] = {}
|
||||
with open(token_file, "r", encoding="utf-8") as f:
|
||||
for line in f.readlines():
|
||||
info = line.rstrip().split("\t")
|
||||
token, id = info[0], int(info[1])
|
||||
assert token not in self.token2id, token
|
||||
self.token2id[token] = id
|
||||
self.pad_id = self.token2id["_"] # padding
|
||||
self.vocab_size = len(self.token2id)
|
||||
try:
|
||||
from tacotron_cleaner.cleaners import custom_english_cleaners as cleaner
|
||||
except Exception as ex:
|
||||
raise RuntimeError(
|
||||
f"{ex}\nPlease run\n"
|
||||
"pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/"
|
||||
)
|
||||
self.cleaner = cleaner
|
||||
|
||||
def texts_to_token_ids(
|
||||
self,
|
||||
texts: List[str],
|
||||
lang: str = "en-us",
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Args:
|
||||
texts:
|
||||
A list of transcripts.
|
||||
intersperse_blank:
|
||||
Whether to intersperse blanks in the token sequence.
|
||||
Used when alignment is from MAS.
|
||||
lang:
|
||||
Language argument passed to phonemize_espeak().
|
||||
|
||||
Returns:
|
||||
Return a list of token id list [utterance][token_id]
|
||||
"""
|
||||
for i in range(len(texts)):
|
||||
# Text normalization
|
||||
texts[i] = self.cleaner(texts[i])
|
||||
|
||||
if self.type == "bpe":
|
||||
token_ids_list = self.sp.encode(texts)
|
||||
|
||||
elif self.type == "phone":
|
||||
token_ids_list = []
|
||||
for text in texts:
|
||||
tokens_list = phonemize_espeak(text.lower(), lang)
|
||||
tokens = []
|
||||
for t in tokens_list:
|
||||
tokens.extend(t)
|
||||
token_ids = []
|
||||
for t in tokens:
|
||||
if t not in self.token2id:
|
||||
logging.warning(f"Skip OOV {t}")
|
||||
continue
|
||||
token_ids.append(self.token2id[t])
|
||||
|
||||
token_ids_list.append(token_ids)
|
||||
else:
|
||||
token_ids_list = []
|
||||
for text in texts:
|
||||
token_ids = []
|
||||
for t in text:
|
||||
if t not in self.token2id:
|
||||
logging.warning(f"Skip OOV {t}")
|
||||
continue
|
||||
token_ids.append(self.token2id[t])
|
||||
|
||||
token_ids_list.append(token_ids)
|
||||
|
||||
return token_ids_list
|
||||
|
||||
def tokens_to_token_ids(
|
||||
self,
|
||||
tokens_list: List[str],
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Args:
|
||||
tokens_list:
|
||||
A list of token list, each corresponding to one utterance.
|
||||
|
||||
Returns:
|
||||
Return a list of token id list [utterance][token_id]
|
||||
"""
|
||||
token_ids_list = []
|
||||
|
||||
for tokens in tokens_list:
|
||||
token_ids = []
|
||||
for t in tokens:
|
||||
if t not in self.token2id:
|
||||
logging.warning(f"Skip OOV {t}")
|
||||
continue
|
||||
token_ids.append(self.token2id[t])
|
||||
|
||||
token_ids_list.append(token_ids)
|
||||
|
||||
return token_ids_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = "我们是5年小米人,是吗? Yes I think so! mr king, 5 years, from 2019 to 2024. 霍...啦啦啦超过90%的人咯...?!9204"
|
||||
tokenizer = TokenizerEmilia()
|
||||
tokens = tokenizer.texts_to_tokens([text])
|
||||
print(f"tokens : {tokens}")
|
||||
tokens2 = "|".join(tokens[0])
|
||||
print(f"tokens2 : {tokens2}")
|
||||
tokens2 = tokens2.split("|")
|
||||
assert tokens[0] == tokens2
|
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,456 +0,0 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||
# Zengwei Yao,
|
||||
# Zengrui Jin,
|
||||
# Han Zhu,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from feature import TorchAudioFbank, TorchAudioFbankConfig
|
||||
from lhotse import CutSet, load_manifest_lazy, validate
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
DynamicBucketingSampler,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
)
|
||||
from lhotse.dataset.collation import collate_audio
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
BatchIO,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed, ifnone
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
SAMPLING_RATE = 24000
|
||||
|
||||
|
||||
class TtsDataModule:
|
||||
"""
|
||||
DataModule for tts experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in TTS
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in TTS tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="TTS data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank_emilia"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=100,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = SpeechSynthesisDataset(
            return_text=True,
            return_tokens=True,
            return_spk_ids=True,
            feature_input_strategy=eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            sampling_rate = SAMPLING_RATE
            config = TorchAudioFbankConfig(
                sampling_rate=sampling_rate,
                n_mels=100,
                n_fft=1024,
                hop_length=256,
            )
            train = SpeechSynthesisDataset(
                return_text=True,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = SAMPLING_RATE
            config = TorchAudioFbankConfig(
                sampling_rate=sampling_rate,
                n_mels=100,
                n_fft=1024,
                hop_length=256,
            )
            validate = SpeechSynthesisDataset(
                return_text=True,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = SpeechSynthesisDataset(
                return_text=True,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        dev_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        dev_dl = DataLoader(
            validate,
            sampler=dev_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return dev_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.info("About to create test dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = SAMPLING_RATE
            config = TorchAudioFbankConfig(
                sampling_rate=sampling_rate,
                n_mels=100,
                n_fft=1024,
                hop_length=256,
            )
            test = SpeechSynthesisDataset(
                return_text=True,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=OnTheFlyFeatures(TorchAudioFbank(config)),
                return_cuts=self.args.return_cuts,
                return_audio=True,
            )
        else:
            test = SpeechSynthesisDataset(
                return_text=True,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
                return_audio=True,
            )
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_emilia_EN_cuts(self) -> CutSet:
        logging.info("About to get the train EN subset")
        return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz")

    @lru_cache()
    def train_emilia_ZH_cuts(self) -> CutSet:
        logging.info("About to get the train ZH subset")
        return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz")

    @lru_cache()
    def dev_emilia_EN_cuts(self) -> CutSet:
        logging.info("About to get the dev EN subset")
        return load_manifest_lazy(
            self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz"
        )

    @lru_cache()
    def dev_emilia_ZH_cuts(self) -> CutSet:
        logging.info("About to get the dev ZH subset")
        return load_manifest_lazy(
            self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz"
        )

    @lru_cache()
    def train_libritts_cuts(self) -> CutSet:
        logging.info(
            "About to get the shuffled train-clean-100, "
            "train-clean-360 and train-other-500 cuts"
        )
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
        )

    @lru_cache()
    def dev_libritts_cuts(self) -> CutSet:
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
        )

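For context, here is a minimal sketch of how this data module is typically driven from a recipe's training script. The class name `TtsDataModule` and the argument values are assumptions for illustration (the actual class definition sits above this excerpt); only `add_arguments()`, the `*_cuts()` helpers and `train_dataloaders()` shown above are taken as given.

```python
# Illustrative wiring only; not part of the deleted file.
import argparse

parser = argparse.ArgumentParser()
TtsDataModule.add_arguments(parser)  # adds --manifest-dir, --max-duration, ...
args = parser.parse_args(["--manifest-dir", "data/fbank_emilia", "--max-duration", "200"])

datamodule = TtsDataModule(args)
train_cuts = datamodule.train_emilia_EN_cuts()       # lazily loaded CutSet
train_dl = datamodule.train_dataloaders(train_cuts)  # DataLoader with dynamic bucketing

for batch in train_dl:
    # each batch is the dict produced by SpeechSynthesisDataset below
    print(batch["features"].shape, batch["features_lens"])
    break
```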
class SpeechSynthesisDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech synthesis task.
    Each item in this dataset is a dict of:

    .. code-block::

        {
            'audio': (B x NumSamples) float tensor
            'features': (B x NumFrames x NumFeatures) float tensor
            'audio_lens': (B, ) int tensor
            'features_lens': (B, ) int tensor
            'text': List[str] of len B  # when return_text=True
            'tokens': List[List[str]]  # when return_tokens=True
            'speakers': List[str] of len B  # when return_spk_ids=True
            'cut': List of Cuts  # when return_cuts=True
        }
    """

    def __init__(
        self,
        cut_transforms: List[Callable[[CutSet], CutSet]] = None,
        feature_input_strategy: BatchIO = PrecomputedFeatures(),
        feature_transforms: Union[Sequence[Callable], Callable] = None,
        return_text: bool = True,
        return_tokens: bool = False,
        return_spk_ids: bool = False,
        return_cuts: bool = False,
        return_audio: bool = False,
    ) -> None:
        super().__init__()

        self.cut_transforms = ifnone(cut_transforms, [])
        self.feature_input_strategy = feature_input_strategy

        self.return_text = return_text
        self.return_tokens = return_tokens
        self.return_spk_ids = return_spk_ids
        self.return_cuts = return_cuts
        self.return_audio = return_audio

        if feature_transforms is None:
            feature_transforms = []
        elif not isinstance(feature_transforms, Sequence):
            feature_transforms = [feature_transforms]

        assert all(
            isinstance(transform, Callable) for transform in feature_transforms
        ), "Feature transforms must be Callable"
        self.feature_transforms = feature_transforms

    def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
        validate_for_tts(cuts)

        for transform in self.cut_transforms:
            cuts = transform(cuts)

        features, features_lens = self.feature_input_strategy(cuts)

        for transform in self.feature_transforms:
            features = transform(features)

        batch = {
            "features": features,
            "features_lens": features_lens,
        }

        if self.return_audio:
            audio, audio_lens = collate_audio(cuts)
            batch["audio"] = audio
            batch["audio_lens"] = audio_lens

        if self.return_text:
            # use normalized text
            text = [cut.supervisions[0].normalized_text for cut in cuts]
            batch["text"] = text

        if self.return_tokens:
            tokens = [cut.tokens for cut in cuts]
            batch["tokens"] = tokens

        if self.return_spk_ids:
            batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]

        if self.return_cuts:
            batch["cut"] = [cut for cut in cuts]

        return batch


def validate_for_tts(cuts: CutSet) -> None:
    validate(cuts)
    for cut in cuts:
        assert (
            len(cut.supervisions) == 1
        ), "Only the Cuts with single supervision are supported."
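Note that `__getitem__` receives a whole `CutSet` (one mini-batch of cuts chosen by the sampler) rather than an integer index, which is why the dataloaders above pass `batch_size=None`. A minimal sketch, assuming `cuts` is any Lhotse `CutSet` of single-supervision cuts with precomputed fbank features:

```python
# Illustrative only; not part of the deleted file.
from lhotse.dataset import DynamicBucketingSampler

dataset = SpeechSynthesisDataset(return_text=True, return_tokens=True, return_spk_ids=True)
sampler = DynamicBucketingSampler(cuts, max_duration=200, shuffle=True, num_buckets=30)

for cut_batch in sampler:          # the sampler yields a CutSet per mini-batch
    batch = dataset[cut_batch]     # __getitem__ collates the whole CutSet
    features = batch["features"]   # (B, T, num_mel_bins) fbank features
    lens = batch["features_lens"]  # (B,)
    break
```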
@ -1,219 +0,0 @@
from typing import Any, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


class AttributeDict(dict):
    def __getattr__(self, key):
        if key in self:
            return self[key]
        raise AttributeError(f"No such attribute '{key}'")

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        if key in self:
            del self[key]
            return
        raise AttributeError(f"No such attribute '{key}'")


def prepare_input(
    params: AttributeDict,
    batch: dict,
    device: torch.device,
    tokenizer: Optional[Any] = None,
    return_tokens: bool = False,
    return_feature: bool = False,
    return_audio: bool = False,
    return_prompt: bool = False,
):
    """
    Parse the features and targets of the current batch.
    Args:
      params:
        It is returned by :func:`get_params`.
      batch:
        It is the return value from iterating the dataloader built on
        `SpeechSynthesisDataset`. See its documentation for the format
        of the `batch`.
      tokenizer:
        Used to convert text or phoneme tokens to token ids.
      device:
        The device of Tensor.
    """
    return_list = []

    if return_tokens:
        assert tokenizer is not None

        if params.token_type == "phone":
            tokens = tokenizer.tokens_to_token_ids(batch["tokens"])
        else:
            tokens = tokenizer.texts_to_token_ids(batch["text"])
        return_list += [tokens]

    if return_feature:
        features = batch["features"].to(device)
        features_lens = batch["features_lens"].to(device)
        return_list += [features * params.feat_scale, features_lens]

    if return_audio:
        return_list += [batch["audio"], batch["audio_lens"]]

    if return_prompt:
        if return_tokens:
            if params.token_type == "phone":
                prompt_tokens = tokenizer.tokens_to_token_ids(batch["prompt"]["tokens"])
            else:
                prompt_tokens = tokenizer.texts_to_token_ids(batch["prompt"]["text"])
            return_list += [prompt_tokens]
        if return_feature:
            prompt_features = batch["prompt"]["features"].to(device)
            prompt_features_lens = batch["prompt"]["features_lens"].to(device)
            return_list += [prompt_features * params.feat_scale, prompt_features_lens]
        if return_audio:
            return_list += [batch["prompt"]["audio"], batch["prompt"]["audio_lens"]]

    return return_list

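Because `prepare_input` assembles its return value conditionally, the order of the returned items depends on which `return_*` flags are set. A hedged sketch of a typical call, assuming `params`, `batch`, `device`, and `tokenizer` are provided by a training loop:

```python
# Illustrative only; not part of the deleted file.
tokens, features, features_lens = prepare_input(
    params=params,
    batch=batch,
    device=device,
    tokenizer=tokenizer,
    return_tokens=True,
    return_feature=True,
)
# With return_prompt=True as well, the prompt entries are appended after the
# non-prompt ones, in the same tokens -> features -> audio order.
```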
def prepare_avg_tokens_durations(features_lens, tokens_lens):
    tokens_durations = []
    for i in range(len(features_lens)):
        utt_duration = features_lens[i]
        avg_token_duration = utt_duration // tokens_lens[i]
        tokens_durations.append([avg_token_duration] * tokens_lens[i])
    return tokens_durations


def pad_labels(y: List[List[int]], pad_id: int, device: torch.device):
    """
    Pad the transcripts to the same length with the pad_id.

    Args:
      y: the transcripts, which is a list of a list

    Returns:
      Return a Tensor of padded transcripts.
    """
    y = [l + [pad_id] for l in y]
    length = max([len(l) for l in y])
    y = [l + [pad_id] * (length - len(l)) for l in y]
    return torch.tensor(y, dtype=torch.int64, device=device)


def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor:
    """
    Gets position in the transcript for each frame, i.e. the position
    in the symbol-sequence to look up.

    Args:
      durations:
        Duration of each token in transcripts.
      num_frames:
        The maximum frame length of the current batch.

    Returns:
      Return a Tensor of shape (batch_size, num_frames)
    """
    durations = [x + [num_frames - sum(x)] for x in durations]
    batch_size = len(durations)
    ans = torch.zeros(batch_size, num_frames, dtype=torch.int64)
    for b in range(batch_size):
        this_dur = durations[b]
        cur_frame = 0
        for i, d in enumerate(this_dur):
            ans[b, cur_frame : cur_frame + d] = i
            cur_frame += d
        assert cur_frame == num_frames, (cur_frame, num_frames)
    return ans

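A small worked example of the two helpers above (toy values, for illustration only):

```python
# Illustrative only; not part of the deleted file.
import torch

labels = pad_labels([[5, 6], [7]], pad_id=0, device=torch.device("cpu"))
# tensor([[5, 6, 0],
#         [7, 0, 0]])  -- one pad_id is always appended, then rows are right-padded

idx = get_tokens_index([[2, 3], [4, 1]], num_frames=6)
# tensor([[0, 0, 1, 1, 1, 2],
#         [0, 0, 0, 0, 1, 2]])  -- trailing frames map to the appended pad token
```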
def to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))


def get_adjusted_batch_count(params: AttributeDict) -> float:
    # returns the number of batches we would have used so far if we had used
    # the reference duration. This is for purposes of set_batch_count().
    return (
        params.batch_idx_train
        * (params.max_duration * params.world_size)
        / params.ref_duration
    )


def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
    if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
    for name, module in model.named_modules():
        if hasattr(module, "batch_count"):
            module.batch_count = batch_count
        if hasattr(module, "name"):
            module.name = name

def condition_time_mask(
    features_lens: torch.Tensor, mask_percent: Tuple[float, float], max_len: int = 0
) -> torch.Tensor:
    """
    Apply time masking.
    Args:
      features_lens:
        input tensor of shape ``(B)``
      mask_percent:
        the (min, max) range of the masked fraction of each utterance.
      max_len:
        the maximum length of the mask.
    Returns:
      Return a 2-D bool tensor (B, T), where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.
    """
    mask_size = (
        torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
        * features_lens
    ).to(torch.int64)
    mask_starts = (
        torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
    ).to(torch.int64)
    mask_ends = mask_starts + mask_size
    max_len = max(max_len, features_lens.max())
    seq_range = torch.arange(0, max_len, device=features_lens.device)
    mask = (seq_range[None, :] >= mask_starts[:, None]) & (
        seq_range[None, :] < mask_ends[:, None]
    )
    return mask

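A minimal usage sketch of `condition_time_mask` with illustrative values:

```python
# Illustrative only; not part of the deleted file.
import torch

features_lens = torch.tensor([10, 6])
mask = condition_time_mask(features_lens, mask_percent=(0.3, 0.7))
mask.shape     # torch.Size([2, 10]): padded to the longest utterance
mask[0].sum()  # number of masked frames for the first utterance (3 to 6 here)
```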
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)

    return expanded_lengths >= lengths.unsqueeze(-1)
File diff suppressed because it is too large
@ -1,645 +0,0 @@
#!/usr/bin/env python3
# Copyright    2025  Xiaomi Corp.  (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script generates speech with our pre-trained ZipVoice or
ZipVoice-Distill models. Required models will be automatically
downloaded from HuggingFace.

Usage:

Note: If you have trouble connecting to HuggingFace,
try switching the endpoint to the mirror site:

export HF_ENDPOINT=https://hf-mirror.com

(1) Inference of a single sentence:

python3 zipvoice/zipvoice_infer.py \
    --model-name "zipvoice_distill" \
    --prompt-wav prompt.wav \
    --prompt-text "I am a prompt." \
    --text "I am a sentence." \
    --res-wav-path result.wav

(2) Inference of a list of sentences:

python3 zipvoice/zipvoice_infer.py \
    --model-name "zipvoice_distill" \
    --test-list test.tsv \
    --res-dir results

`--model-name` can be `zipvoice` or `zipvoice_distill`,
which are the models before and after distillation, respectively.

Each line of `test.tsv` is in the format of
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
"""

import argparse
import datetime as dt
import os

import numpy as np
import safetensors.torch
import torch
import torch.nn as nn
import torchaudio
from feature import TorchAudioFbank, TorchAudioFbankConfig
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from model import get_distill_model, get_model
from tokenizer import TokenizerEmilia
from utils import AttributeDict
from vocos import Vocos

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-name",
        type=str,
        default="zipvoice_distill",
        choices=["zipvoice", "zipvoice_distill"],
        help="The model used for inference",
    )

    parser.add_argument(
        "--test-list",
        type=str,
        default=None,
        help="The list of prompt speech, prompt_transcription, "
        "and text to synthesize, in the format of "
        "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
    )

    parser.add_argument(
        "--prompt-wav",
        type=str,
        default=None,
        help="The prompt wav to mimic",
    )

    parser.add_argument(
        "--prompt-text",
        type=str,
        default=None,
        help="The transcription of the prompt wav",
    )

    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="The text to synthesize",
    )

    parser.add_argument(
        "--res-dir",
        type=str,
        default="results",
        help="""
        Path of the directory for generated wavs,
        used when test-list is not None
        """,
    )

    parser.add_argument(
        "--res-wav-path",
        type=str,
        default="result.wav",
        help="""
        Path of the generated wav file,
        used when test-list is None
        """,
    )

    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=None,
        help="The scale of classifier-free guidance during inference.",
    )

    parser.add_argument(
        "--num-step",
        type=int,
        default=None,
        help="The number of sampling steps.",
    )

    parser.add_argument(
        "--feat-scale",
        type=float,
        default=0.1,
        help="The scale factor of fbank feature",
    )

    parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Control speech speed, 1.0 means normal, >1.0 means speed up",
    )

    parser.add_argument(
        "--t-shift",
        type=float,
        default=0.5,
        help="Shift t to smaller ones if t_shift < 1.0",
    )

    parser.add_argument(
        "--target-rms",
        type=float,
        default=0.1,
        help="Target speech normalization rms value",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=666,
        help="Random seed",
    )

    add_model_arguments(parser)

    return parser

def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--fm-decoder-downsampling-factor",
        type=str,
        default="1,2,4,2,1",
        help="Downsampling factor for each stack of encoder layers.",
    )

    parser.add_argument(
        "--fm-decoder-num-layers",
        type=str,
        default="2,2,4,4,4",
        help="Number of zipformer encoder layers per stack, comma separated.",
    )

    parser.add_argument(
        "--fm-decoder-cnn-module-kernel",
        type=str,
        default="31,15,7,15,31",
        help="Sizes of convolutional kernels in convolution modules "
        "in each encoder stack: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--fm-decoder-feedforward-dim",
        type=int,
        default=1536,
        help="Feedforward dimension of the zipformer encoder layers, "
        "per stack, comma separated.",
    )

    parser.add_argument(
        "--fm-decoder-num-heads",
        type=int,
        default=4,
        help="Number of attention heads in the zipformer encoder layers: "
        "a single int or comma-separated list.",
    )

    parser.add_argument(
        "--fm-decoder-dim",
        type=int,
        default=512,
        help="Embedding dimension in encoder stacks: a single int "
        "or comma-separated list.",
    )

    parser.add_argument(
        "--text-encoder-downsampling-factor",
        type=str,
        default="1",
        help="Downsampling factor for each stack of encoder layers.",
    )

    parser.add_argument(
        "--text-encoder-num-layers",
        type=str,
        default="4",
        help="Number of zipformer encoder layers per stack, comma separated.",
    )

    parser.add_argument(
        "--text-encoder-feedforward-dim",
        type=int,
        default=512,
        help="Feedforward dimension of the zipformer encoder layers, "
        "per stack, comma separated.",
    )

    parser.add_argument(
        "--text-encoder-cnn-module-kernel",
        type=str,
        default="9",
        help="Sizes of convolutional kernels in convolution modules in "
        "each encoder stack: a single int or comma-separated list.",
    )

    parser.add_argument(
        "--text-encoder-num-heads",
        type=int,
        default=4,
        help="Number of attention heads in the zipformer encoder layers: "
        "a single int or comma-separated list.",
    )

    parser.add_argument(
        "--text-encoder-dim",
        type=int,
        default=192,
        help="Embedding dimension in encoder stacks: "
        "a single int or comma-separated list.",
    )

    parser.add_argument(
        "--query-head-dim",
        type=int,
        default=32,
        help="Query/key dimension per head in encoder stacks: "
        "a single int or comma-separated list.",
    )

    parser.add_argument(
        "--value-head-dim",
        type=int,
        default=12,
        help="Value dimension per head in encoder stacks: "
        "a single int or comma-separated list.",
    )

    parser.add_argument(
        "--pos-head-dim",
        type=int,
        default=4,
        help="Positional-encoding dimension per head in encoder stacks: "
        "a single int or comma-separated list.",
    )

    parser.add_argument(
        "--pos-dim",
        type=int,
        default=48,
        help="Positional-encoding embedding dimension",
    )

    parser.add_argument(
        "--time-embed-dim",
        type=int,
        default=192,
        help="Embedding dimension of timestamps embedding.",
    )

    parser.add_argument(
        "--text-embed-dim",
        type=int,
        default=192,
        help="Embedding dimension of text embedding.",
    )

def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "sampling_rate": 24000,
            "frame_shift_ms": 256 / 24000 * 1000,
            "feat_dim": 100,
        }
    )

    return params


def get_vocoder():
    vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    return vocoder

def generate_sentence(
    save_path: str,
    prompt_text: str,
    prompt_wav: str,
    text: str,
    model: nn.Module,
    vocoder: nn.Module,
    tokenizer: TokenizerEmilia,
    feature_extractor: TorchAudioFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    """
    Generate the waveform of a text based on a given prompt
    waveform and its transcription.

    Args:
        save_path (str): Path to save the generated wav.
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (str): Path to the prompt wav file.
        text (str): Text to be synthesized into a waveform.
        model (nn.Module): The model used for generation.
        vocoder (nn.Module): The vocoder used to convert features to waveforms.
        tokenizer (TokenizerEmilia): The tokenizer used to convert text to tokens.
        feature_extractor (TorchAudioFbank): The feature extractor used to
            extract acoustic features.
        device (torch.device): The device on which computations are performed.
        num_step (int, optional): Number of steps for decoding. Defaults to 16.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features.
            Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.
    Returns:
        metrics (dict): Dictionary containing time and real-time
            factor metrics for processing.
    """
    # Convert text to tokens
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Load and preprocess prompt wav
    prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

    if prompt_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=prompt_sampling_rate, new_freq=sampling_rate
        )
        prompt_wav = resampler(prompt_wav)

    # Extract features from prompt wav
    prompt_features = feature_extractor.extract(
        prompt_wav, sampling_rate=sampling_rate
    ).to(device)
    prompt_features = prompt_features.unsqueeze(0) * feat_scale
    prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

    # Start timing
    start_t = dt.datetime.now()

    # Generate features
    (
        pred_features,
        pred_features_lens,
        pred_prompt_features,
        pred_prompt_features_lens,
    ) = model.sample(
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        prompt_features_lens=prompt_features_lens,
        speed=speed,
        t_shift=t_shift,
        duration="predict",
        num_step=num_step,
        guidance_scale=guidance_scale,
    )

    # Postprocess predicted features
    pred_features = pred_features.permute(0, 2, 1) / feat_scale  # (B, C, T)

    # Start vocoder processing
    start_vocoder_t = dt.datetime.now()
    wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

    # Calculate processing times and real-time factors
    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    metrics = {
        "t": t,
        "t_no_vocoder": t_no_vocoder,
        "t_vocoder": t_vocoder,
        "wav_seconds": wav_seconds,
        "rtf": rtf,
        "rtf_no_vocoder": rtf_no_vocoder,
        "rtf_vocoder": rtf_vocoder,
    }

    # Adjust wav volume if necessary
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)

    return metrics

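The returned metrics are real-time factors, i.e. processing time divided by the duration of the generated audio. A worked example with made-up numbers:

```python
# Illustrative numbers only; not part of the deleted file.
t, t_vocoder, wav_seconds = 0.8, 0.3, 4.0   # total time, vocoder time, audio length
rtf = t / wav_seconds                            # 0.20  -> 5x faster than real time
rtf_vocoder = t_vocoder / wav_seconds            # 0.075
rtf_no_vocoder = (t - t_vocoder) / wav_seconds   # ~0.125 (time up to the vocoder call)
```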
def generate(
    res_dir: str,
    test_list: str,
    model: nn.Module,
    vocoder: nn.Module,
    tokenizer: TokenizerEmilia,
    feature_extractor: TorchAudioFbank,
    device: torch.device,
    num_step: int = 16,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    total_t = []
    total_t_no_vocoder = []
    total_t_vocoder = []
    total_wav_seconds = []

    with open(test_list, "r") as fr:
        lines = fr.readlines()

    for i, line in enumerate(lines):
        wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
        save_path = f"{res_dir}/{wav_name}.wav"
        metrics = generate_sentence(
            save_path=save_path,
            prompt_text=prompt_text,
            prompt_wav=prompt_wav,
            text=text,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            device=device,
            num_step=num_step,
            guidance_scale=guidance_scale,
            speed=speed,
            t_shift=t_shift,
            target_rms=target_rms,
            feat_scale=feat_scale,
            sampling_rate=sampling_rate,
        )
        print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
        total_t.append(metrics["t"])
        total_t_no_vocoder.append(metrics["t_no_vocoder"])
        total_t_vocoder.append(metrics["t_vocoder"])
        total_wav_seconds.append(metrics["wav_seconds"])

    print(f"Average RTF: {np.sum(total_t)/np.sum(total_wav_seconds):.4f}")
    print(
        f"Average RTF w/o vocoder: "
        f"{np.sum(total_t_no_vocoder)/np.sum(total_wav_seconds):.4f}"
    )
    print(
        f"Average RTF vocoder: "
        f"{np.sum(total_t_vocoder)/np.sum(total_wav_seconds):.4f}"
    )

@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    model_defaults = {
        "zipvoice": {
            "num_step": 16,
            "guidance_scale": 1.0,
        },
        "zipvoice_distill": {
            "num_step": 8,
            "guidance_scale": 3.0,
        },
    }

    model_specific_defaults = model_defaults.get(params.model_name, {})

    for param, value in model_specific_defaults.items():
        if getattr(params, param) == parser.get_default(param):
            setattr(params, param, value)
            print(f"Setting {param} to default value: {value}")

    assert (params.test_list is not None) ^ (
        (params.prompt_wav and params.prompt_text and params.text) is not None
    ), (
        "For inference, please provide prompts and text with either '--test-list'"
        " or '--prompt-wav, --prompt-text and --text'."
    )

    if torch.cuda.is_available():
        params.device = torch.device("cuda", 0)
    else:
        params.device = torch.device("cpu")

    token_file = hf_hub_download("zhu-han/ZipVoice", filename="tokens_emilia.txt")

    tokenizer = TokenizerEmilia(token_file)

    params.vocab_size = tokenizer.vocab_size
    params.pad_id = tokenizer.pad_id
    fix_random_seed(params.seed)

    if params.model_name == "zipvoice_distill":
        model = get_distill_model(params)
        model_ckpt = hf_hub_download(
            "zhu-han/ZipVoice", filename="exp_zipvoice_distill/model.safetensors"
        )
    else:
        model = get_model(params)
        model_ckpt = hf_hub_download(
            "zhu-han/ZipVoice", filename="exp_zipvoice/model.safetensors"
        )

    safetensors.torch.load_model(model, model_ckpt)

    model = model.to(params.device)
    model.eval()

    vocoder = get_vocoder()
    vocoder = vocoder.to(params.device)
    vocoder.eval()

    config = TorchAudioFbankConfig(
        sampling_rate=params.sampling_rate,
        n_mels=100,
        n_fft=1024,
        hop_length=256,
    )
    feature_extractor = TorchAudioFbank(config)

    if params.test_list:
        os.makedirs(params.res_dir, exist_ok=True)
        generate(
            res_dir=params.res_dir,
            test_list=params.test_list,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            device=params.device,
            num_step=params.num_step,
            guidance_scale=params.guidance_scale,
            speed=params.speed,
            t_shift=params.t_shift,
            target_rms=params.target_rms,
            feat_scale=params.feat_scale,
            sampling_rate=params.sampling_rate,
        )
    else:
        generate_sentence(
            save_path=params.res_wav_path,
            prompt_text=params.prompt_text,
            prompt_wav=params.prompt_wav,
            text=params.text,
            model=model,
            vocoder=vocoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            device=params.device,
            num_step=params.num_step,
            guidance_scale=params.guidance_scale,
            speed=params.speed,
            t_shift=params.t_shift,
            target_rms=params.target_rms,
            feat_scale=params.feat_scale,
            sampling_rate=params.sampling_rate,
        )
    print("Done")


if __name__ == "__main__":
    main()