diff --git a/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py b/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py index 57f677fcb..7d42a00a5 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py @@ -82,9 +82,12 @@ class SpeechSynthesisDataset(torch.utils.data.Dataset): text = [cut.supervisions[0].text for cut in cuts] batch["text"] = text - if self.return_tokens: + if self.return_tokens and "speech_tokens" in cuts[0].supervisions[0].custom: # tokens = [cut.tokens for cut in cuts] - tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts] + # tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts] + tokens = [cut.supervisions[0].custom["speech_tokens"] for cut in cuts] + # convert each space-separated token string into a list of ints + tokens = [list(map(int, token.split())) for token in tokens] batch["tokens"] = tokens if self.return_spk_ids: diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 37dcf531e..61b1c709c 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -379,6 +379,8 @@ def get_tokenizer(vocab_file_path: str): def get_model(params): vocab_char_map, vocab_size = get_tokenizer(params.tokens) + vocab_char_map, vocab_size = None, 6561 + # https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/file/view/master?fileName=cosyvoice.yaml&status=1#L36 # bigvgan 100 dim features n_mel_channels = 100 n_fft = 1024 @@ -556,14 +558,38 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) +def insert_zeros_optimized(arr): + # cosyvoice, 25 tokens/sec + # bigvgan sample_rate/hop_length 24000/256 frames/sec + # For every 4 cosyvoice tokens, insert pad tokens to extend them to 15 entries, matching the bigvgan frame length + # We choose 4,4,4,3 to match 15 frames + three, two = [-1] * 3, [-1] * 2 + return [ + x for i, e in enumerate(arr) for x in ([e] + three if i % 4 < 3 else [e] + two) + ] + + def 
prepare_input(batch: dict, device: torch.device): """Parse batch data""" - text_inputs = batch["text"] - # texts.extend(convert_char_to_pinyin([text], polyphone=true)) - text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True) + # text_inputs = batch["text"] + # text_inputs = convert_char_to_pinyin(text_inputs, polyphone=True) mel_spec = batch["features"] mel_lengths = batch["features_lens"] + + semantic_tokens = [] + for i in range(len(batch["tokens"])): + tokens = batch["tokens"][i] + tokens = insert_zeros_optimized(tokens) + semantic_tokens.append(tokens) + # pad to the same length, B,T, with pad value -1 + max_len = max([len(tokens) for tokens in semantic_tokens]) + text_inputs = torch.full((len(semantic_tokens), max_len), -1, dtype=torch.long).to( + device + ) + for i, tokens in enumerate(semantic_tokens): + text_inputs[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long) + return text_inputs, mel_spec.to(device), mel_lengths.to(device) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py index 80ba17318..eab7588b7 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py @@ -174,7 +174,7 @@ class TtsDataModule: logging.info("About to create train dataset") train = SpeechSynthesisDataset( return_text=True, - return_tokens=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -234,7 +234,7 @@ class TtsDataModule: else: validate = SpeechSynthesisDataset( return_text=True, - return_tokens=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -265,7 +265,7 @@ class TtsDataModule: else: test = SpeechSynthesisDataset( return_text=True, - return_tokens=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) diff --git 
a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh index cf86b3fa5..b1cc4eb10 100755 --- a/egs/wenetspeech4tts/TTS/prepare.sh +++ b/egs/wenetspeech4tts/TTS/prepare.sh @@ -130,6 +130,9 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then data/fbank/${prefix}_cuts_validtest.jsonl.gz \ data/fbank/${prefix}_cuts_test.jsonl.gz + + # zcat "data/fbank/${prefix}_cuts_${subset}.jsonl.gz" | head -n 100 | gzip > "data/fbank/${prefix}_cuts_${subset}_top100.jsonl.gz" + rm data/fbank/${prefix}_cuts_validtest.jsonl.gz n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))