diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py index fee66da48..5152ae675 100755 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -38,7 +38,7 @@ from lhotse.audio import RecordingSet from lhotse.features.base import FeatureExtractor, register_extractor from lhotse.supervision import SupervisionSet from lhotse.utils import Seconds, compute_num_frames -from matcha.utils.audio import mel_spectrogram +from matcha.audio import mel_spectrogram from icefall.utils import get_executor diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py index cf5069b11..c0eebcde0 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -73,8 +73,6 @@ class ModelWrapper(torch.nn.Module): )["mel"] # mel: (batch_size, feat_dim, num_frames) - # audio = self.vocoder(mel).clamp(-1, 1).squeeze(1) - return mel diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py index 89a6b33ae..8fc0ec3ac 100755 --- a/egs/ljspeech/TTS/matcha/inference.py +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -28,7 +28,7 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=2810, + default=4000, help="""It specifies the checkpoint to use for decoding. Note: Epoch counts from 1. """, @@ -44,6 +44,13 @@ def get_parser(): """, ) + parser.add_argument( + "--vocoder", + type=Path, + default="./generator_v1", + help="Path to the vocoder", + ) + parser.add_argument( "--tokens", type=Path, @@ -61,6 +68,7 @@ def get_parser(): def load_vocoder(checkpoint_path): + checkpoint_path = str(checkpoint_path) if checkpoint_path.endswith("v1"): h = AttributeDict(v1) elif checkpoint_path.endswith("v2"): @@ -142,10 +150,17 @@ def main(): logging.info("About to create model") model = get_model(params) + + if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file(): + raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist") + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) model.eval() - vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1") + if not Path(params.vocoder).is_file(): + raise ValueError(f"{params.vocoder} does not exist") + + vocoder = load_vocoder(params.vocoder) denoiser = Denoiser(vocoder, mode="zeros") texts = [ diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 308a06b1f..d31ce1301 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -90,7 +90,7 @@ def save_checkpoint( if params: for k, v in params.items(): - assert k not in checkpoint + assert k not in checkpoint, k checkpoint[k] = v torch.save(checkpoint, filename)