[zipvoice] Add requirements.txt and pinyin.txt, remove k2 from pretrained model inference. (#1965)
* Add requirements.txt and pinyin.txt needed by zipvoice
* Simplify the requirements for pretrained model inference
commit 762f965cf7
parent 06539d2b9d
@@ -39,15 +39,6 @@ source venv/bin/activate
 * Install the required packages:
 
 ```bash
-# Install pytorch and k2.
-# If you want to use different versions, please refer to https://k2-fsa.org/get-started/k2/ for details.
-# For users in China mainland, please refer to https://k2-fsa.org/zh-CN/get-started/k2/
-
-pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
-pip install k2==1.24.4.dev20250208+cuda12.1.torch2.5.1 -f https://k2-fsa.github.io/k2/cuda.html
-
-# Install other dependencies.
-pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
 pip install -r requirements.txt
 ```
 
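With this change, inference with pretrained models only needs `pip install -r requirements.txt`: k2 is no longer a hard requirement, because the Swoosh activations in `scaling.py` (see the hunks below) fall back to a pure-PyTorch implementation when k2 is missing, and `piper_phonemize` is now pulled in through `requirements.txt` itself.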
@@ -97,6 +88,16 @@ The following steps show how to train a model from scratch on Emilia and LibriTTS
 ### 0. Install dependencies for training
 
 ```bash
+# Install pytorch and k2.
+# If you want to use different versions, please refer to https://k2-fsa.org/get-started/k2/ for details.
+# For users in China mainland, please refer to https://k2-fsa.org/zh-CN/get-started/k2/
+
+# Note: Make sure you have installed the correct version of PyTorch and k2 that matches your CUDA version.
+# For example, if you want to use pytorch 2.5.1 and you are using CUDA 12.1, you can install PyTorch and k2 as follows:
+
+pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
+pip install k2==1.24.4.dev20250208+cuda12.1.torch2.5.1 -f https://k2-fsa.github.io/k2/cuda.html
+
 pip install -r ../../requirements.txt
 ```
 
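Since training still requires matched torch/k2 builds, a quick environment check can confirm the pairing took. The following is a sketch, not part of the recipe; it assumes only that the two packages above were installed:

```python
# Sanity-check sketch (not part of the recipe): verify torch and k2 import
# together and report the installed versions, which should match the pinned
# builds above (torch 2.5.1 + the corresponding k2 CUDA wheel).
from importlib.metadata import version

import torch
import k2  # noqa: F401  # importing k2 at all exercises the torch/k2 pairing

print("torch:", version("torch"))
print("k2:", version("k2"))
print("CUDA available:", torch.cuda.is_available())
```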
@@ -403,7 +404,7 @@ on three test sets, i.e., LibriSpeech-PC test-clean, Seed-TTS test-en and Seed-TTS test-zh
 
 ```bibtex
 @article{zhu-2025-zipvoice,
   title={ZipVoice: Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching},
   author={Han Zhu and Wei Kang and Zengwei Yao and Liyong Guo and Fangjun Kuang and Zhaoqing Li and Weiji Zhuang and Long Lin and Daniel Povey},
   journal={arXiv preprint arXiv:2506.13053},
   year={2025},
@@ -19,11 +19,13 @@
 import argparse
 import logging
 import os
+from concurrent.futures import ProcessPoolExecutor as Pool
 from pathlib import Path
 from typing import Optional
-from concurrent.futures import ProcessPoolExecutor as Pool
 
+import lhotse
 import torch
+from feature import TorchAudioFbank, TorchAudioFbankConfig
 from lhotse import (
     CutSet,
     LilcomChunkyWriter,
@@ -31,9 +33,6 @@ from lhotse import (
     set_audio_duration_mismatch_tolerance,
 )
 
-from feature import TorchAudioFbank, TorchAudioFbankConfig
-import lhotse
-
 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
 # Do this outside of main() in case it needs to take effect
egs/zipvoice/local/pinyin.txt (new file, 1550 lines; diff suppressed because it is too large)
@@ -24,15 +24,14 @@ This file reads the texts in given manifest and save the cleaned new cuts.
 """
 
 import argparse
-import logging
 import glob
+import logging
 import os
+from concurrent.futures import ProcessPoolExecutor as Pool
 from pathlib import Path
 from typing import List
 
 from lhotse import CutSet, load_manifest_lazy
-from concurrent.futures import ProcessPoolExecutor as Pool
-
 from tokenizer import (
     is_alphabet,
     is_chinese,
egs/zipvoice/requirements.txt (new file, 17 lines)
@@ -0,0 +1,17 @@
+--find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
+
+torch
+torchaudio
+huggingface_hub
+lhotse
+safetensors
+vocos
+
+# Normalization
+cn2an
+inflect
+
+# Tokenization
+jieba
+piper_phonemize
+pypinyin
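Note that the `--find-links` line at the top of the file is what lets a plain `pip install -r requirements.txt` cover `piper_phonemize` as well: pip additionally searches that icefall wheel page for matching wheels, replacing the separate `pip install piper_phonemize -f ...` step removed from the README above.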
@@ -18,9 +18,17 @@
 import logging
 import math
 import random
+import sys
 from typing import Optional, Tuple, Union
 
-import k2
+try:
+    import k2
+except Exception as ex:
+    logging.warning(
+        "k2 is not installed correctly. Swoosh functions will fall back to "
+        "pytorch implementation."
+    )
+
 import torch
 import torch.nn as nn
 from torch import Tensor
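The guard works because a failed `import k2` leaves no `"k2"` entry in `sys.modules`, which is exactly what the `"k2" not in sys.modules` checks in the forward() methods below rely on. A minimal self-contained sketch of the pattern (illustrative names only, not the actual scaling.py code):

```python
# Sketch of the guarded-import pattern: import k2 if available, otherwise
# warn once and let callers detect the absence via sys.modules.
import logging
import sys

try:
    import k2  # noqa: F401
except Exception as ex:
    # Python removes any partial "k2" entry from sys.modules when the
    # import fails, so later code can test availability without retrying.
    logging.warning(f"k2 is not available ({ex}); using the PyTorch fallback.")

def k2_available() -> bool:
    # Mirrors the `"k2" not in sys.modules` guards in the forward() methods.
    return "k2" in sys.modules

print("use k2 kernels:", k2_available())
```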
@@ -1398,7 +1406,11 @@ class SwooshLFunction(torch.autograd.Function):
 class SwooshL(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-L activation."""
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if (
+            torch.jit.is_scripting()
+            or torch.jit.is_tracing()
+            or "k2" not in sys.modules
+        ):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
@@ -1472,7 +1484,11 @@ class SwooshRFunction(torch.autograd.Function):
 class SwooshR(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-R activation."""
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if (
+            torch.jit.is_scripting()
+            or torch.jit.is_tracing()
+            or "k2" not in sys.modules
+        ):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
         if not x.requires_grad:
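For reference, the two fallback branches above compute the activations in plain PyTorch. A runnable sketch of that math, with `torch.logaddexp` standing in for the file's local `logaddexp` helper (formulas copied from the branches above):

```python
# Sketch of the pure-PyTorch Swoosh fallbacks taken when k2 is absent.
import torch

def swoosh_l(x: torch.Tensor) -> torch.Tensor:
    # Swoosh-L: log(1 + exp(x - 4)) - 0.08 * x - 0.035
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035

def swoosh_r(x: torch.Tensor) -> torch.Tensor:
    # Swoosh-R: log(1 + exp(x - 1)) - 0.08 * x - 0.313261687
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687

x = torch.linspace(-4.0, 4.0, steps=5)
print(swoosh_l(x))
print(swoosh_r(x))
```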
@@ -1636,7 +1652,11 @@ class ActivationDropoutAndLinear(torch.nn.Module):
         self.dropout_shared_dim = dropout_shared_dim
 
     def forward(self, x: Tensor):
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if (
+            torch.jit.is_scripting()
+            or torch.jit.is_tracing()
+            or "k2" not in sys.modules
+        ):
             if self.activation == "SwooshL":
                 x = SwooshLForward(x)
             elif self.activation == "SwooshR":
@@ -321,7 +321,8 @@ def tokenize_ZH(text: str) -> List[str]:
             if final != "":
                 phones.append(final)
         return phones
-    except:
+    except Exception as ex:
+        logging.warning(f"Tokenize ZH failed: {ex}")
         return []
 
 
@@ -332,7 +333,8 @@ def tokenize_EN(text: str) -> List[str]:
         tokens = phonemize_espeak(text, "en-us")
         tokens = reduce(lambda x, y: x + y, tokens)
         return tokens
-    except:
+    except Exception as ex:
+        logging.warning(f"Tokenize EN failed: {ex}")
         return []
 
 
@@ -561,7 +563,7 @@ class TokenizerLibriTTS(object):
 
 if __name__ == "__main__":
     text = "我们是5年小米人,是吗? Yes I think so! mr king, 5 years, from 2019 to 2024. 霍...啦啦啦超过90%的人咯...?!9204"
-    tokenizer = Tokenizer()
+    tokenizer = TokenizerEmilia()
     tokens = tokenizer.texts_to_tokens([text])
     print(f"tokens : {tokens}")
     tokens2 = "|".join(tokens[0])