From 8c6141e6d34a2763e62ae1e3d1c342c130046ee7 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Tue, 5 Nov 2024 11:40:23 +0800
Subject: [PATCH] Update infer.py

---
 egs/ljspeech/TTS/matcha/infer.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/egs/ljspeech/TTS/matcha/infer.py b/egs/ljspeech/TTS/matcha/infer.py
index 869ed20f3..5b1332b97 100755
--- a/egs/ljspeech/TTS/matcha/infer.py
+++ b/egs/ljspeech/TTS/matcha/infer.py
@@ -93,13 +93,13 @@ def to_waveform(
 ) -> torch.Tensor:
     audio = vocoder(mel).clamp(-1, 1)
     audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
-    return audio.cpu().squeeze()
+    return audio.squeeze()
 
 
-def process_text(text: str, tokenizer: Tokenizer) -> dict:
+def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
     x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
-    x = torch.tensor(x, dtype=torch.long)
-    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
+    x = torch.tensor(x, dtype=torch.long, device=device)
+    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
     return {"x_orig": text, "x": x, "x_lengths": x_lengths}
 
 
@@ -110,9 +110,10 @@ def synthesise(
     text: str,
     length_scale: float,
     temperature: float,
+    device: str = "cpu",
     spks=None,
 ) -> dict:
-    text_processed = process_text(text, tokenizer)
+    text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
     start_t = dt.datetime.now()
     output = model.synthesise(
         text_processed["x"],
@@ -161,7 +162,7 @@ def infer_dataset(
         for batch_idx, batch in enumerate(dl):
             batch_size = len(batch["tokens"])
 
-            texts = batch["supervisions"]["text"]
+            texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
             audio = batch["audio"]
             audio_lens = batch["audio_lens"].tolist()
 
@@ -175,6 +176,7 @@ def infer_dataset(
                     text=texts[i],
                     length_scale=params.length_scale,
                     temperature=params.temperature,
+                    device=device,
                 )
                 output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
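
Note (not part of the patch): a minimal sketch of how the updated `synthesise` / `to_waveform` entry points can be driven once this change is applied, so the input tensors are created directly on the model's device. The `model`, `vocoder`, `denoiser`, and `tokenizer` objects are assumed to be built elsewhere in infer.py, and the `n_timesteps` argument and all literal values below are illustrative assumptions, not taken from this patch.

import torch

# Pick a device; the patch threads this value through process_text()/synthesise()
# so token-id tensors are allocated on the same device as the model.
device = "cuda" if torch.cuda.is_available() else "cpu"

output = synthesise(
    model=model,            # assumed: MatchaTTS model already moved to `device`
    tokenizer=tokenizer,    # assumed: Tokenizer instance from the recipe
    n_timesteps=2,          # assumed: number of ODE steps (params.n_timesteps in infer_dataset)
    text="Ask not what your country can do for you.",
    length_scale=1.0,
    temperature=0.667,
    device=device,          # new argument introduced by this patch
)
# to_waveform() runs the vocoder and denoiser and returns a squeezed CPU waveform
# clamped to [-1, 1], as in the patched function above.
audio = to_waveform(output["mel"], vocoder, denoiser)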