Update infer.py

This commit is contained in:
zr_jin 2024-11-05 11:40:23 +08:00
parent 8058988cac
commit 8c6141e6d3

View File

@@ -93,13 +93,13 @@ def to_waveform(
) -> torch.Tensor: ) -> torch.Tensor:
audio = vocoder(mel).clamp(-1, 1) audio = vocoder(mel).clamp(-1, 1)
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
return audio.cpu().squeeze() return audio.squeeze()
def process_text(text: str, tokenizer: Tokenizer) -> dict: def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
x = torch.tensor(x, dtype=torch.long) x = torch.tensor(x, dtype=torch.long, device=device)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu") x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
return {"x_orig": text, "x": x, "x_lengths": x_lengths} return {"x_orig": text, "x": x, "x_lengths": x_lengths}
@@ -110,9 +110,10 @@ def synthesise(
text: str, text: str,
length_scale: float, length_scale: float,
temperature: float, temperature: float,
device: str = "cpu",
spks=None, spks=None,
) -> dict: ) -> dict:
text_processed = process_text(text, tokenizer) text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
start_t = dt.datetime.now() start_t = dt.datetime.now()
output = model.synthesise( output = model.synthesise(
text_processed["x"], text_processed["x"],
@@ -161,7 +162,7 @@ def infer_dataset(
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
batch_size = len(batch["tokens"]) batch_size = len(batch["tokens"])
texts = batch["supervisions"]["text"] texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
audio = batch["audio"] audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist() audio_lens = batch["audio_lens"].tolist()
@@ -175,6 +176,7 @@ def infer_dataset(
text=texts[i], text=texts[i],
length_scale=params.length_scale, length_scale=params.length_scale,
temperature=params.temperature, temperature=params.temperature,
device=device,
) )
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)