mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-15 04:52:22 +00:00
update readme
This commit is contained in:
parent
d55a534af8
commit
e66f133bcf
@ -49,3 +49,54 @@ To inference, use:
|
|||||||
--epoch 400 \
|
--epoch 400 \
|
||||||
--tokens data/tokens.txt
|
--tokens data/tokens.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# [VALL-E](https://arxiv.org/abs/2301.02111)
|
||||||
|
|
||||||
|
./valle contains the code for training VALL-E TTS model.
|
||||||
|
|
||||||
|
Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_libritts). The demo of the model trained with libritts and [libritts-r](https://www.openslr.org/141/) is available [here](https://huggingface.co/spaces/yuekai/valle-libritts-demo).
|
||||||
|
|
||||||
|
Preparation:
|
||||||
|
|
||||||
|
```
|
||||||
|
bash prepare.sh --start-stage 4
|
||||||
|
```
|
||||||
|
|
||||||
|
The training command is given below:
|
||||||
|
|
||||||
|
```
|
||||||
|
world_size=8
|
||||||
|
exp_dir=exp/valle
|
||||||
|
|
||||||
|
## Train AR model
|
||||||
|
python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \
|
||||||
|
--num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \
|
||||||
|
--share-embedding true --norm-first true --add-prenet false \
|
||||||
|
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
|
||||||
|
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
|
||||||
|
--num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \
|
||||||
|
--exp-dir ${exp_dir} --world-size ${world_size}
|
||||||
|
|
||||||
|
## Train NAR model
|
||||||
|
# cd ${exp_dir}
|
||||||
|
# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1
|
||||||
|
# cd -
|
||||||
|
python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \
|
||||||
|
--num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \
|
||||||
|
--share-embedding true --norm-first true --add-prenet false \
|
||||||
|
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
|
||||||
|
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
|
||||||
|
--num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \
|
||||||
|
--exp-dir ${exp_dir} --world-size ${world_size}
|
||||||
|
```
|
||||||
|
|
||||||
|
To inference, use:
|
||||||
|
```
|
||||||
|
huggingface-cli login
|
||||||
|
huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_libritts
|
||||||
|
top_p=1.0
|
||||||
|
python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \
|
||||||
|
--top-k -1 --temperature 1.0 \
|
||||||
|
--text ./libritts.txt \
|
||||||
|
--checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt --top-p ${top_p}
|
||||||
|
```
|
||||||
|
@ -1,14 +1,6 @@
|
|||||||
# Introduction
|
# Introduction
|
||||||
|
|
||||||
LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members.
|
[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset.
|
||||||
The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus.
|
|
||||||
The main differences from the LibriSpeech corpus are listed below:
|
|
||||||
1. The audio files are at 24kHz sampling rate.
|
|
||||||
2. The speech is split at sentence breaks.
|
|
||||||
3. Both original and normalized texts are included.
|
|
||||||
4. Contextual information (e.g., neighbouring sentences) can be extracted.
|
|
||||||
5. Utterances with significant background noise are excluded.
|
|
||||||
For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced.
|
|
||||||
|
|
||||||
> [!CAUTION]
|
> [!CAUTION]
|
||||||
> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS).
|
> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS).
|
||||||
@ -24,28 +16,57 @@ For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriS
|
|||||||
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.
|
> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties.
|
||||||
|
|
||||||
|
|
||||||
# VITS
|
# [VALL-E](https://arxiv.org/abs/2301.02111)
|
||||||
|
|
||||||
This recipe provides a VITS model trained on the LibriTTS dataset.
|
./valle contains the code for training VALL-E TTS model.
|
||||||
|
|
||||||
Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-libritts-vits-2024-10-30).
|
Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_wenetspeech4tts). The demo of the model trained with Wenetspeech4TTS Premium (945 hours) is available [here](https://huggingface.co/spaces/yuekai/valle_wenetspeech4tts_demo).
|
||||||
|
|
||||||
|
Preparation:
|
||||||
|
|
||||||
|
```
|
||||||
|
bash prepare.sh
|
||||||
|
```
|
||||||
|
|
||||||
The training command is given below:
|
The training command is given below:
|
||||||
|
|
||||||
```
|
```
|
||||||
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
world_size=8
|
||||||
./vits/train.py \
|
exp_dir=exp/valle
|
||||||
--world-size 4 \
|
|
||||||
--num-epochs 400 \
|
## Train AR model
|
||||||
--start-epoch 1 \
|
python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \
|
||||||
--use-fp16 1 \
|
--num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \
|
||||||
--exp-dir vits/exp \
|
--share-embedding true --norm-first true --add-prenet false \
|
||||||
--max-duration 500
|
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
|
||||||
|
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
|
||||||
|
--num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \
|
||||||
|
--exp-dir ${exp_dir} --world-size ${world_size}
|
||||||
|
|
||||||
|
## Train NAR model
|
||||||
|
# cd ${exp_dir}
|
||||||
|
# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1
|
||||||
|
# cd -
|
||||||
|
python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \
|
||||||
|
--num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \
|
||||||
|
--share-embedding true --norm-first true --add-prenet false \
|
||||||
|
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
|
||||||
|
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
|
||||||
|
--num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \
|
||||||
|
--exp-dir ${exp_dir} --world-size ${world_size}
|
||||||
```
|
```
|
||||||
|
|
||||||
To inference, use:
|
To inference, use:
|
||||||
```
|
```
|
||||||
./vits/infer.py \
|
huggingface-cli login
|
||||||
--exp-dir vits/exp \
|
huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_wenetspeech4tts
|
||||||
--epoch 400 \
|
top_p=1.0
|
||||||
--tokens data/tokens.txt
|
python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \
|
||||||
|
--top-k -1 --temperature 1.0 \
|
||||||
|
--text ./aishell3.txt \
|
||||||
|
--checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \
|
||||||
|
--text-extractor pypinyin_initials_finals --top-p ${top_p}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# Credits
|
||||||
|
- [vall-e](https://github.com/lifeiteng/vall-e)
|
||||||
|
@ -16,8 +16,12 @@
|
|||||||
Phonemize Text and EnCodec Audio.
|
Phonemize Text and EnCodec Audio.
|
||||||
|
|
||||||
Usage example:
|
Usage example:
|
||||||
python3 bin/tokenizer.py \
|
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
|
||||||
--src_dir ./data/manifests --output_dir ./data/tokenized
|
--text-extractor ${text_extractor} \
|
||||||
|
--audio-extractor ${audio_extractor} \
|
||||||
|
--batch-duration 2500 --prefix "wenetspeech4tts" \
|
||||||
|
--src-dir "data/manifests" --split 100 \
|
||||||
|
--output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100"
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
@ -523,7 +527,7 @@ def main():
|
|||||||
"wenetspeech4tts",
|
"wenetspeech4tts",
|
||||||
]:
|
]:
|
||||||
part = part.resample(24000)
|
part = part.resample(24000)
|
||||||
assert args.prefix_lower() in [
|
assert args.prefix.lower() in [
|
||||||
"ljspeech",
|
"ljspeech",
|
||||||
"aishell",
|
"aishell",
|
||||||
"baker",
|
"baker",
|
||||||
@ -557,36 +561,26 @@ def main():
|
|||||||
# TextTokenizer
|
# TextTokenizer
|
||||||
if args.text_extractor:
|
if args.text_extractor:
|
||||||
for c in tqdm(part):
|
for c in tqdm(part):
|
||||||
if (
|
if args.prefix == "ljspeech":
|
||||||
args.prefix == "baker"
|
text = c.supervisions[0].custom["normalized_text"]
|
||||||
and args.text_extractor == "labeled_pinyin"
|
text = text.replace(""", '"').replace(""", '"')
|
||||||
):
|
phonemes = tokenize_text(text_tokenizer, text=text)
|
||||||
phonemes = c.supervisions[0].custom["tokens"]["text"]
|
elif args.prefix in [
|
||||||
unique_symbols.update(phonemes)
|
"aishell",
|
||||||
|
"aishell2",
|
||||||
|
"wenetspeech4tts",
|
||||||
|
"libritts",
|
||||||
|
"libritts-r",
|
||||||
|
]:
|
||||||
|
phonemes = tokenize_text(
|
||||||
|
text_tokenizer, text=c.supervisions[0].text
|
||||||
|
)
|
||||||
|
if c.supervisions[0].custom is None:
|
||||||
|
c.supervisions[0].custom = {}
|
||||||
|
c.supervisions[0].normalized_text = c.supervisions[0].text
|
||||||
else:
|
else:
|
||||||
if args.prefix == "ljspeech":
|
raise NotImplementedError(f"{args.prefix}")
|
||||||
text = c.supervisions[0].custom["normalized_text"]
|
unique_symbols.update(phonemes)
|
||||||
text = text.replace(""", '"').replace(""", '"')
|
|
||||||
phonemes = tokenize_text(text_tokenizer, text=text)
|
|
||||||
elif args.prefix in [
|
|
||||||
"aishell",
|
|
||||||
"aishell2",
|
|
||||||
"wenetspeech4tts",
|
|
||||||
"libritts",
|
|
||||||
"libritts-r",
|
|
||||||
]:
|
|
||||||
phonemes = tokenize_text(
|
|
||||||
text_tokenizer, text=c.supervisions[0].text
|
|
||||||
)
|
|
||||||
if c.supervisions[0].custom is None:
|
|
||||||
c.supervisions[0].custom = {}
|
|
||||||
c.supervisions[0].normalized_text = c.supervisions[
|
|
||||||
0
|
|
||||||
].text
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"{args.prefix}")
|
|
||||||
c.supervisions[0].custom["tokens"] = {"text": phonemes}
|
|
||||||
unique_symbols.update(phonemes)
|
|
||||||
c.tokens = phonemes
|
c.tokens = phonemes
|
||||||
assert c.supervisions[
|
assert c.supervisions[
|
||||||
0
|
0
|
||||||
|
@ -5,13 +5,12 @@ set -eou pipefail
|
|||||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||||
|
|
||||||
j=16
|
stage=1
|
||||||
stage=2
|
stop_stage=4
|
||||||
stop_stage=2
|
|
||||||
|
|
||||||
dl_dir=$PWD/download
|
dl_dir=$PWD/download
|
||||||
|
|
||||||
dataset_parts="-p Basic" # -p Premium for Premium dataset only
|
dataset_parts="Premium" # Basic for all 10k hours data, Premium for about 10% of the data
|
||||||
|
|
||||||
text_extractor="pypinyin_initials_finals" # default is espeak for English
|
text_extractor="pypinyin_initials_finals" # default is espeak for English
|
||||||
audio_extractor="Encodec" # or Fbank
|
audio_extractor="Encodec" # or Fbank
|
||||||
@ -62,37 +61,37 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
|
python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \
|
||||||
--text-extractor ${text_extractor} \
|
--text-extractor ${text_extractor} \
|
||||||
--audio-extractor ${audio_extractor} \
|
--audio-extractor ${audio_extractor} \
|
||||||
--batch-duration 2500 \
|
--batch-duration 2500 --prefix "wenetspeech4tts" \
|
||||||
--prefix "wenetspeech4tts" \
|
|
||||||
--src-dir "data/manifests" \
|
--src-dir "data/manifests" \
|
||||||
--split 100 \
|
--split 100 \
|
||||||
--output-dir "${audio_feats_dir}/${prefix}_baisc_split_100"
|
--output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100"
|
||||||
|
cp ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100/unique_text_tokens.k2symbols ${audio_feats_dir}
|
||||||
fi
|
fi
|
||||||
touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done
|
touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||||
log "Stage 13: Combine features for basic"
|
log "Stage 3: Combine features"
|
||||||
if [ ! -f ${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz ]; then
|
if [ ! -f ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz ]; then
|
||||||
pieces=$(find ${audio_feats_dir}/wenetspeech4tts_baisc_split_100 -name "*.jsonl.gz")
|
pieces=$(find ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100 -name "*.jsonl.gz")
|
||||||
lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz
|
lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
log "Stage 3: Prepare wenetspeech4tts train/dev/test"
|
log "Stage 4: Prepare wenetspeech4tts train/dev/test"
|
||||||
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then
|
if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then
|
||||||
|
|
||||||
lhotse subset --first 400 \
|
lhotse subset --first 400 \
|
||||||
${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz \
|
${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \
|
||||||
${audio_feats_dir}/cuts_dev.jsonl.gz
|
${audio_feats_dir}/cuts_dev.jsonl.gz
|
||||||
|
|
||||||
lhotse subset --last 400 \
|
lhotse subset --last 400 \
|
||||||
${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz \
|
${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \
|
||||||
${audio_feats_dir}/cuts_test.jsonl.gz
|
${audio_feats_dir}/cuts_test.jsonl.gz
|
||||||
|
|
||||||
lhotse copy \
|
lhotse copy \
|
||||||
${audio_feats_dir}/wenetspeech4tts_cuts_Baisc.jsonl.gz \
|
${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \
|
||||||
${audio_feats_dir}/cuts_train.jsonl.gz
|
${audio_feats_dir}/cuts_train.jsonl.gz
|
||||||
|
|
||||||
touch ${audio_feats_dir}/.wenetspeech4tts.train.done
|
touch ${audio_feats_dir}/.wenetspeech4tts.train.done
|
||||||
|
@ -14,21 +14,20 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
Phonemize Text and EnCodec Audio.
|
This script is used to synthesize speech from text prompts and audio prompts.
|
||||||
|
|
||||||
Usage example:
|
Usage example:
|
||||||
python3 bin/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \
|
python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \
|
||||||
--checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \
|
--checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \
|
||||||
--text-prompts "KNOT one point one five miles per hour." \
|
--text-prompts "KNOT one point one five miles per hour." \
|
||||||
--audio-prompts ./prompts/8463_294825_000043_000000.wav \
|
--audio-prompts ./prompts/8463_294825_000043_000000.wav \
|
||||||
--text "To get up and running quickly just follow the steps below."
|
--text "To get up and running quickly just follow the steps below."
|
||||||
|
|
||||||
python3 bin/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \
|
top_p=1.0
|
||||||
|
python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \
|
||||||
--top-k -1 --temperature 1.0 \
|
--top-k -1 --temperature 1.0 \
|
||||||
--text-prompts "" \
|
--text ./aishell3.txt \
|
||||||
--audio-prompts "" \
|
--checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \
|
||||||
--text ./libritts.txt \
|
--text-extractor pypinyin_initials_finals --top-p ${top_p}
|
||||||
--checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
@ -43,9 +42,9 @@ import torchaudio
|
|||||||
from compute_neural_codec_and_prepare_text_tokens import (
|
from compute_neural_codec_and_prepare_text_tokens import (
|
||||||
AudioTokenizer,
|
AudioTokenizer,
|
||||||
TextTokenizer,
|
TextTokenizer,
|
||||||
tokenize_audio,
|
|
||||||
tokenize_text,
|
tokenize_text,
|
||||||
)
|
)
|
||||||
|
from encodec.utils import convert_audio
|
||||||
from k2 import symbol_table
|
from k2 import symbol_table
|
||||||
from tokenizer import get_text_token_collater
|
from tokenizer import get_text_token_collater
|
||||||
from valle import VALLE
|
from valle import VALLE
|
||||||
@ -71,7 +70,7 @@ def get_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--manifest",
|
"--text",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
help="prompt text\t prompt audio\ttarget text\ttarget audio",
|
help="prompt text\t prompt audio\ttarget text\ttarget audio",
|
||||||
@ -126,6 +125,13 @@ def get_args():
|
|||||||
help="Do continual task.",
|
help="Do continual task.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--repetition-aware-sampling",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether AR Decoder do valle-2 repetition-aware sampling. https://arxiv.org/pdf/2406.05370",
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -159,15 +165,28 @@ def load_model(checkpoint, device):
|
|||||||
return model, params.text_tokens
|
return model, params.text_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str):
|
||||||
|
# Load and pre-process the audio waveform
|
||||||
|
wav, sr = torchaudio.load(audio_path)
|
||||||
|
wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
|
||||||
|
wav = wav.unsqueeze(0)
|
||||||
|
|
||||||
|
# Extract discrete codes from EnCodec
|
||||||
|
with torch.no_grad():
|
||||||
|
encoded_frames = tokenizer.encode(wav)
|
||||||
|
return encoded_frames
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
text_tokenizer = TextTokenizer(backend=args.text_extractor)
|
text_tokenizer = TextTokenizer(backend=args.text_extractor)
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
model, text_tokens = load_model(args.checkpoint, device)
|
model, text_tokens = load_model(args.checkpoint, device)
|
||||||
|
|
||||||
text_collater = get_text_token_collater(text_tokens)
|
text_collater = get_text_token_collater(text_tokens)
|
||||||
|
|
||||||
audio_tokenizer = AudioTokenizer()
|
audio_tokenizer = AudioTokenizer()
|
||||||
@ -194,8 +213,7 @@ def main():
|
|||||||
# https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py
|
# https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py
|
||||||
with open(args.text) as f:
|
with open(args.text) as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
fields = line.strip().split("\t")
|
fields = line.strip().split(" ")
|
||||||
# fields = line.strip().split(" ")
|
|
||||||
fields = [item for item in fields if item]
|
fields = [item for item in fields if item]
|
||||||
assert len(fields) == 4
|
assert len(fields) == 4
|
||||||
prompt_text, prompt_audio, text, audio_path = fields
|
prompt_text, prompt_audio, text, audio_path = fields
|
||||||
@ -223,6 +241,7 @@ def main():
|
|||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
|
ras=args.repetition_aware_sampling,
|
||||||
)
|
)
|
||||||
|
|
||||||
samples = audio_tokenizer.decode(
|
samples = audio_tokenizer.decode(
|
||||||
@ -264,6 +283,7 @@ def main():
|
|||||||
top_k=args.top_k,
|
top_k=args.top_k,
|
||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
top_p=args.top_p,
|
top_p=args.top_p,
|
||||||
|
ras=args.repetition_aware_sampling,
|
||||||
)
|
)
|
||||||
|
|
||||||
if audio_prompts != []:
|
if audio_prompts != []:
|
||||||
@ -274,11 +294,6 @@ def main():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
torch.set_num_interop_threads(1)
|
|
||||||
torch._C._jit_set_profiling_executor(False)
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
|
||||||
torch._C._set_graph_executor_optimize(False)
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
@ -24,9 +24,9 @@ world_size=8
|
|||||||
exp_dir=exp/valle
|
exp_dir=exp/valle
|
||||||
|
|
||||||
## Train AR model
|
## Train AR model
|
||||||
python3 train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \
|
python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \
|
||||||
--num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \
|
--num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \
|
||||||
--model-name valle --share-embedding true --norm-first true --add-prenet false \
|
--share-embedding true --norm-first true --add-prenet false \
|
||||||
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
|
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
|
||||||
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
|
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
|
||||||
--num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \
|
--num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \
|
||||||
@ -36,9 +36,9 @@ python3 train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-durat
|
|||||||
# cd ${exp_dir}
|
# cd ${exp_dir}
|
||||||
# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1
|
# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1
|
||||||
# cd -
|
# cd -
|
||||||
python3 train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \
|
python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \
|
||||||
--num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \
|
--num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \
|
||||||
--model-name valle --share-embedding true --norm-first true --add-prenet false \
|
--share-embedding true --norm-first true --add-prenet false \
|
||||||
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
|
--decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \
|
||||||
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
|
--base-lr 0.03 --warmup-steps 200 --average-period 0 \
|
||||||
--num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \
|
--num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \
|
||||||
@ -1032,30 +1032,10 @@ def run(rank, world_size, args):
|
|||||||
model_parameters = model.parameters()
|
model_parameters = model.parameters()
|
||||||
|
|
||||||
if params.optimizer_name == "ScaledAdam":
|
if params.optimizer_name == "ScaledAdam":
|
||||||
parameters_names = []
|
|
||||||
if params.train_stage: # != 0
|
|
||||||
_model = model.module if isinstance(model, DDP) else model
|
|
||||||
parameters_names.append(
|
|
||||||
[
|
|
||||||
name_param_pair[0]
|
|
||||||
for name_param_pair in _model.stage_named_parameters(
|
|
||||||
params.train_stage
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
parameters_names.append(
|
|
||||||
[name_param_pair[0] for name_param_pair in model.named_parameters()]
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer = ScaledAdam(
|
optimizer = ScaledAdam(
|
||||||
model_parameters,
|
model_parameters,
|
||||||
lr=params.base_lr,
|
lr=params.base_lr,
|
||||||
betas=(0.9, 0.95),
|
|
||||||
clipping_scale=2.0,
|
clipping_scale=2.0,
|
||||||
parameters_names=parameters_names,
|
|
||||||
show_dominant_parameters=False,
|
|
||||||
clipping_update_period=1000,
|
|
||||||
)
|
)
|
||||||
elif params.optimizer_name == "AdamW":
|
elif params.optimizer_name == "AdamW":
|
||||||
optimizer = torch.optim.AdamW(
|
optimizer = torch.optim.AdamW(
|
||||||
@ -1112,7 +1092,7 @@ def run(rank, world_size, args):
|
|||||||
train_dl = dataset.train_dataloaders(
|
train_dl = dataset.train_dataloaders(
|
||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
valid_dl = dataset.valid_dataloaders(valid_cuts)
|
valid_dl = dataset.dev_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if params.oom_check:
|
if params.oom_check:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
|
@ -195,8 +195,8 @@ class TtsDataModule:
|
|||||||
"""
|
"""
|
||||||
logging.info("About to create train dataset")
|
logging.info("About to create train dataset")
|
||||||
train = SpeechSynthesisDataset(
|
train = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=True,
|
||||||
return_tokens=False,
|
return_tokens=True,
|
||||||
return_spk_ids=False,
|
return_spk_ids=False,
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
@ -251,8 +251,8 @@ class TtsDataModule:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
validate = SpeechSynthesisDataset(
|
validate = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=True,
|
||||||
return_tokens=False,
|
return_tokens=True,
|
||||||
return_spk_ids=False,
|
return_spk_ids=False,
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
@ -279,8 +279,8 @@ class TtsDataModule:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
test = SpeechSynthesisDataset(
|
test = SpeechSynthesisDataset(
|
||||||
return_text=False,
|
return_text=True,
|
||||||
return_tokens=False,
|
return_tokens=True,
|
||||||
return_spk_ids=False,
|
return_spk_ids=False,
|
||||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
|
@ -31,8 +31,8 @@ from torchmetrics.classification import MulticlassAccuracy
|
|||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
|
NUM_TEXT_TOKENS = 5000
|
||||||
from .visualizer import visualize
|
NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins
|
||||||
|
|
||||||
|
|
||||||
class PromptedFeatures:
|
class PromptedFeatures:
|
||||||
@ -1194,15 +1194,6 @@ class VALLE(nn.Module):
|
|||||||
|
|
||||||
return y_emb, prefix_len
|
return y_emb, prefix_len
|
||||||
|
|
||||||
def visualize(
|
|
||||||
self,
|
|
||||||
predicts: Tuple[torch.Tensor],
|
|
||||||
batch: Dict[str, Union[List, torch.Tensor]],
|
|
||||||
output_dir: str,
|
|
||||||
limit: int = 4,
|
|
||||||
) -> None:
|
|
||||||
visualize(predicts, batch, output_dir, limit=limit)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -1404,6 +1395,7 @@ class VALLE(nn.Module):
|
|||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
|
ras: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -1418,6 +1410,8 @@ class VALLE(nn.Module):
|
|||||||
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
||||||
temperature: (`optional`) float
|
temperature: (`optional`) float
|
||||||
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
||||||
|
ras: (`optional`) bool
|
||||||
|
Whether to use repetition-aware sampling. Default to False.
|
||||||
Returns:
|
Returns:
|
||||||
Return the predicted audio code matrix.
|
Return the predicted audio code matrix.
|
||||||
"""
|
"""
|
||||||
@ -1473,7 +1467,6 @@ class VALLE(nn.Module):
|
|||||||
mask=xy_attn_mask,
|
mask=xy_attn_mask,
|
||||||
)
|
)
|
||||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
ras = True
|
|
||||||
samples = topk_sampling(
|
samples = topk_sampling(
|
||||||
logits,
|
logits,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@ -1490,8 +1483,6 @@ class VALLE(nn.Module):
|
|||||||
):
|
):
|
||||||
if prompts.shape[1] == y.shape[1]:
|
if prompts.shape[1] == y.shape[1]:
|
||||||
raise SyntaxError("well trained model shouldn't reach here.")
|
raise SyntaxError("well trained model shouldn't reach here.")
|
||||||
|
|
||||||
print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
|
|
||||||
break
|
break
|
||||||
|
|
||||||
y = torch.concat([y, samples], dim=1)
|
y = torch.concat([y, samples], dim=1)
|
||||||
@ -1716,24 +1707,14 @@ def topk_sampling(
|
|||||||
repetition_aware_sampling=False,
|
repetition_aware_sampling=False,
|
||||||
preceding_tokens=None,
|
preceding_tokens=None,
|
||||||
):
|
):
|
||||||
# temperature: (`optional`) float
|
|
||||||
# The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
|
||||||
# top_k: (`optional`) int
|
|
||||||
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
|
||||||
# top_p: (`optional`) float
|
|
||||||
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
|
||||||
|
|
||||||
# Temperature (higher temperature => more likely to sample low probability tokens)
|
|
||||||
if temperature != 1.0:
|
if temperature != 1.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
# Top-p/top-k filtering
|
# Top-p/top-k filtering
|
||||||
logits = top_k_top_p_filtering(
|
logits_filtered = top_k_top_p_filtering(
|
||||||
logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
||||||
)
|
)
|
||||||
# Sample
|
# Sample
|
||||||
probs = F.softmax(logits, dim=-1)
|
probs = F.softmax(logits_filtered, dim=-1)
|
||||||
# print top 10 value and index
|
|
||||||
print("top 10 value and index", torch.topk(probs, 10), top_p)
|
|
||||||
tokens = torch.multinomial(probs, num_samples=1)
|
tokens = torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
if repetition_aware_sampling:
|
if repetition_aware_sampling:
|
||||||
@ -1758,16 +1739,7 @@ def topk_sampling(
|
|||||||
# check if the repeat ratio exceeds the threshold
|
# check if the repeat ratio exceeds the threshold
|
||||||
if (item == tokens[i]).sum() / window_size > threshold:
|
if (item == tokens[i]).sum() / window_size > threshold:
|
||||||
# replace the target code ct′ by random sampling
|
# replace the target code ct′ by random sampling
|
||||||
# make sure we don't sample the same token, by setting the probability of the token to 0
|
|
||||||
# logits[i][tokens[i]] = -float("Inf")
|
|
||||||
probs = F.softmax(logits[i], dim=-1)
|
probs = F.softmax(logits[i], dim=-1)
|
||||||
token_new = torch.multinomial(probs, num_samples=1)
|
token_new = torch.multinomial(probs, num_samples=1)
|
||||||
|
|
||||||
print(
|
|
||||||
f"Repetition Aware Sampling: {item}, {tokens[i]} -> {token_new}"
|
|
||||||
)
|
|
||||||
print("probs", probs, logits.shape)
|
|
||||||
tokens[i] = token_new
|
tokens[i] = token_new
|
||||||
else:
|
|
||||||
print(f"Not trigger: {i}, {item}, {tokens[i]}")
|
|
||||||
return tokens
|
return tokens
|
||||||
|
Loading…
x
Reference in New Issue
Block a user