[zipvoice] Add requirements.txt and pinyin.txt, remove k2 from pretrained model inference. (#1965)

* Add requirements.txt and pinyin.txt needed by zipvoice

* Simplify the requirements for pretrained model inference
Wei Kang 2025-06-18 18:38:46 +08:00 committed by GitHub
parent 06539d2b9d
commit 762f965cf7
7 changed files with 1612 additions and 24 deletions

View File

@@ -39,15 +39,6 @@ source venv/bin/activate
 * Install the required packages:
 ```bash
-# Install pytorch and k2.
-# If you want to use different versions, please refer to https://k2-fsa.org/get-started/k2/ for details.
-# For users in China mainland, please refer to https://k2-fsa.org/zh-CN/get-started/k2/
-pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
-pip install k2==1.24.4.dev20250208+cuda12.1.torch2.5.1 -f https://k2-fsa.github.io/k2/cuda.html
-# Install other dependencies.
-pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
 pip install -r requirements.txt
 ```
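With the k2-specific steps gone, pretrained-model inference needs only `pip install -r requirements.txt`. A quick sketch of mine (not part of the commit) showing why this is safe, using the same `sys.modules` gate added to the Swoosh activations later in this diff:

```python
# Sketch (mine, not part of the commit): the pretrained-inference path no
# longer needs k2. The Swoosh activations in this diff fall back to pure
# PyTorch whenever "k2" is absent from sys.modules.
import sys

import torch  # installed via requirements.txt

assert "k2" not in sys.modules  # holds in a fresh interpreter without k2
print("torch", torch.__version__, "loaded; k2 not required for inference")
```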
@@ -97,6 +88,16 @@ The following steps show how to train a model from scratch on Emilia and LibriTTS datasets.
 ### 0. Install dependencies for training
 ```bash
+# Install pytorch and k2.
+# If you want to use different versions, please refer to https://k2-fsa.org/get-started/k2/ for details.
+# For users in mainland China, please refer to https://k2-fsa.org/zh-CN/get-started/k2/
+# Note: make sure the versions of PyTorch and k2 you install match your CUDA version.
+# For example, to use PyTorch 2.5.1 with CUDA 12.1, install PyTorch and k2 as follows:
+pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
+pip install k2==1.24.4.dev20250208+cuda12.1.torch2.5.1 -f https://k2-fsa.github.io/k2/cuda.html
 pip install -r ../../requirements.txt
 ```
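Version mismatches between the k2 wheel and the local torch/CUDA build are the usual failure mode of this step. A minimal sanity check (my own sketch, not part of the recipe):

```python
# Sanity check (not part of the commit): confirm the k2 wheel matches this
# torch/CUDA build. A mismatched wheel typically fails at import time.
import torch

print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)

import k2  # raises ImportError/OSError if built against a different torch/CUDA

print("k2 imported successfully from", k2.__file__)
```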
@@ -403,7 +404,7 @@ on three test sets, i.e., LibriSpeech-PC test-clean, Seed-TTS test-en and Seed-TTS test-zh.
 ```bibtex
 @article{zhu-2025-zipvoice,
-      title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching},
+      title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching},
       author={Han Zhu and Wei Kang and Zengwei Yao and Liyong Guo and Fangjun Kuang and Zhaoqing Li and Weiji Zhuang and Long Lin and Daniel Povey},
       journal={arXiv preprint arXiv:2506.13053},
       year={2025},

View File

@@ -19,11 +19,13 @@
 import argparse
 import logging
 import os
+from concurrent.futures import ProcessPoolExecutor as Pool
 from pathlib import Path
 from typing import Optional

-from concurrent.futures import ProcessPoolExecutor as Pool
+import lhotse
 import torch
+from feature import TorchAudioFbank, TorchAudioFbankConfig
 from lhotse import (
     CutSet,
     LilcomChunkyWriter,
@@ -31,9 +33,6 @@ from lhotse import (
     set_audio_duration_mismatch_tolerance,
 )
-from feature import TorchAudioFbank, TorchAudioFbankConfig
-import lhotse

 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
 # Do this outside of main() in case it needs to take effect

File diff suppressed because it is too large

View File

@@ -24,15 +24,14 @@ This file reads the texts in given manifest and save the cleaned new cuts.
 """

 import argparse
-import logging
 import glob
+import logging
 import os
+from concurrent.futures import ProcessPoolExecutor as Pool
 from pathlib import Path
 from typing import List

 from lhotse import CutSet, load_manifest_lazy
-from concurrent.futures import ProcessPoolExecutor as Pool
 from tokenizer import (
     is_alphabet,
     is_chinese,
View File

@@ -0,0 +1,17 @@
+--find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
+torch
+torchaudio
+huggingface_hub
+lhotse
+safetensors
+vocos
+
+# Normalization
+cn2an
+inflect
+
+# Tokenization
+jieba
+piper_phonemize
+pypinyin
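The `--find-links` line points pip at the extra wheel index for `piper_phonemize`, so a plain `pip install -r requirements.txt` resolves everything in one step. A hedged check of mine (not part of the commit) that all inference dependencies import cleanly:

```python
# Hedged sanity check: verify that each inference dependency from
# requirements.txt can be imported. Names below are the import names,
# which happen to match the PyPI names for this list.
import importlib

for name in [
    "torch", "torchaudio", "huggingface_hub", "lhotse", "safetensors",
    "vocos", "cn2an", "inflect", "jieba", "piper_phonemize", "pypinyin",
]:
    importlib.import_module(name)
    print(f"ok: {name}")
```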

View File

@@ -18,9 +18,17 @@
 import logging
 import math
 import random
+import sys
 from typing import Optional, Tuple, Union

-import k2
+try:
+    import k2
+except Exception as ex:
+    logging.warning(
+        "k2 is not installed correctly. Swoosh functions will fallback to "
+        "pytorch implementation."
+    )
+
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -1398,7 +1406,11 @@ class SwooshLFunction(torch.autograd.Function):
 class SwooshL(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-L activation."""
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if (
+            torch.jit.is_scripting()
+            or torch.jit.is_tracing()
+            or "k2" not in sys.modules
+        ):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
@@ -1472,7 +1484,11 @@ class SwooshRFunction(torch.autograd.Function):
 class SwooshR(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-R activation."""
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if (
+            torch.jit.is_scripting()
+            or torch.jit.is_tracing()
+            or "k2" not in sys.modules
+        ):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
         if not x.requires_grad:
@@ -1636,7 +1652,11 @@ class ActivationDropoutAndLinear(torch.nn.Module):
         self.dropout_shared_dim = dropout_shared_dim

     def forward(self, x: Tensor):
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if (
+            torch.jit.is_scripting()
+            or torch.jit.is_tracing()
+            or "k2" not in sys.modules
+        ):
             if self.activation == "SwooshL":
                 x = SwooshLForward(x)
             elif self.activation == "SwooshR":
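The new `"k2" not in sys.modules` branch keeps Swoosh-L/R usable without k2's fused kernels; the closed forms are log(1 + e^(x-4)) - 0.08x - 0.035 and log(1 + e^(x-1)) - 0.08x - 0.313261687. A standalone sketch of mine mirroring the fallback expressions above (not the icefall module itself):

```python
# Standalone sketch: the pure-PyTorch Swoosh-L and Swoosh-R used when "k2"
# is absent from sys.modules. torch.logaddexp(0, t) computes log(1 + exp(t))
# stably, matching the logaddexp helper used in the diff above.
import torch


def swoosh_l(x: torch.Tensor) -> torch.Tensor:
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035


def swoosh_r(x: torch.Tensor) -> torch.Tensor:
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687


print(swoosh_l(torch.linspace(-4.0, 4.0, 5)))
print(swoosh_r(torch.linspace(-4.0, 4.0, 5)))
```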

View File

@@ -321,7 +321,8 @@ def tokenize_ZH(text: str) -> List[str]:
             if final != "":
                 phones.append(final)
         return phones
-    except:
+    except Exception as ex:
+        logging.warning(f"Tokenize ZH failed: {ex}")
         return []
@@ -332,7 +333,8 @@ def tokenize_EN(text: str) -> List[str]:
         tokens = phonemize_espeak(text, "en-us")
         tokens = reduce(lambda x, y: x + y, tokens)
         return tokens
-    except:
+    except Exception as ex:
+        logging.warning(f"Tokenize EN failed: {ex}")
         return []
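Replacing the bare `except:` matters beyond the new log line: a bare except also traps `KeyboardInterrupt` and `SystemExit`, turning a Ctrl-C during tokenization into a silent empty result. A minimal illustration (mine, not from the commit):

```python
# Minimal illustration: `except Exception` lets KeyboardInterrupt/SystemExit
# propagate, while still logging ordinary failures and returning [].
import logging


def tokenize_or_empty(fn, text):
    try:
        return fn(text)
    except Exception as ex:  # a bare `except:` would also swallow Ctrl-C
        logging.warning(f"Tokenize failed: {ex}")
        return []


print(tokenize_or_empty(lambda t: t.split(), "hello world"))  # ['hello', 'world']
print(tokenize_or_empty(lambda t: 1 / 0, "boom"))  # [] plus a logged warning
```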
@@ -561,7 +563,7 @@ class TokenizerLibriTTS(object):

 if __name__ == "__main__":
     text = "我们是5年小米人,是吗? Yes I think so! mr king, 5 years, from 2019 to 2024. 霍...啦啦啦超过90%的人咯...?!9204"
-    tokenizer = Tokenizer()
+    tokenizer = TokenizerEmilia()
     tokens = tokenizer.texts_to_tokens([text])
     print(f"tokens : {tokens}")
     tokens2 = "|".join(tokens[0])