Update infer.py

This commit is contained in:
zr_jin 2024-11-05 11:40:23 +08:00
parent 8058988cac
commit 8c6141e6d3

View File

@ -93,13 +93,13 @@ def to_waveform(
) -> torch.Tensor:
audio = vocoder(mel).clamp(-1, 1)
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
return audio.cpu().squeeze()
return audio.squeeze()
def process_text(text: str, tokenizer: Tokenizer) -> dict:
def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
x = torch.tensor(x, dtype=torch.long)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
x = torch.tensor(x, dtype=torch.long, device=device)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
@ -110,9 +110,10 @@ def synthesise(
text: str,
length_scale: float,
temperature: float,
device: str = "cpu",
spks=None,
) -> dict:
text_processed = process_text(text, tokenizer)
text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
start_t = dt.datetime.now()
output = model.synthesise(
text_processed["x"],
@ -161,7 +162,7 @@ def infer_dataset(
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["tokens"])
texts = batch["supervisions"]["text"]
texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
@ -175,6 +176,7 @@ def infer_dataset(
text=texts[i],
length_scale=params.length_scale,
temperature=params.temperature,
device=device,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)