diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py index 27b947b77..e9ec548f3 100755 --- a/egs/wenetspeech4tts/TTS/valle/train.py +++ b/egs/wenetspeech4tts/TTS/valle/train.py @@ -4,6 +4,7 @@ # Mingshuang Luo) # Copyright 2023 (authors: Feiteng Li) # Copyright 2024 (authors: Yuekai Zhang) +# Copyright 2024 Tsinghua University (authors: Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -48,10 +49,8 @@ python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max import argparse import copy import logging -import os import random import warnings -from contextlib import nullcontext from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union @@ -686,9 +685,9 @@ def compute_validation_loss( output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}") output_dir.mkdir(parents=True, exist_ok=True) if isinstance(model, DDP): - model.module.visualize(predicts, batch, output_dir=output_dir) + model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir) else: - model.visualize(predicts, batch, output_dir=output_dir) + model.visualize(predicts, batch, tokenizer, output_dir=output_dir) return tot_loss diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index 40501736b..206b843ba 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -23,6 +23,7 @@ import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn +from tokenizer import TextTokenCollater from torch import Tensor from torch.nn import Linear, Module from torch.nn import functional as F @@ -1664,13 +1665,15 @@ class VALLE(nn.Module): self, predicts: Tuple[torch.Tensor], batch: Dict[str, Union[List, torch.Tensor]], + tokenizer: TextTokenCollater, output_dir: str, limit: int = 4, ) -> None: - text_tokens = batch["text_tokens"].to("cpu").detach().numpy() - text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() audio_features = batch["audio_features"].to("cpu").detach().numpy() audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy() + + tokens = batch["tokens"] + text_tokens, text_tokens_lens = tokenizer(tokens) assert text_tokens.ndim == 2 utt_ids, texts = batch["utt_id"], batch["text"]