mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-13 12:02:21 +00:00
Update infer.py
This commit is contained in:
parent 8058988cac
commit 8c6141e6d3
@@ -93,13 +93,13 @@ def to_waveform(
 ) -> torch.Tensor:
     audio = vocoder(mel).clamp(-1, 1)
     audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
-    return audio.cpu().squeeze()
+    return audio.squeeze()


-def process_text(text: str, tokenizer: Tokenizer) -> dict:
+def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
     x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
-    x = torch.tensor(x, dtype=torch.long)
-    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
+    x = torch.tensor(x, dtype=torch.long, device=device)
+    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
     return {"x_orig": text, "x": x, "x_lengths": x_lengths}

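The hunk above adds a device argument to process_text so the token-id tensor and its length are created directly on the target device rather than on the CPU. Below is a minimal standalone sketch of that placement pattern; the helper name make_token_batch and the token ids are made up for illustration and are not part of infer.py.

import torch

def make_token_batch(token_ids, device="cpu"):
    # Hypothetical helper: build the token tensor and its length directly on
    # `device`, mirroring what the updated process_text() does after tokenizing.
    x = torch.tensor([token_ids], dtype=torch.long, device=device)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    return {"x": x, "x_lengths": x_lengths}

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch = make_token_batch([1, 42, 7, 2], device=device)  # made-up token ids
    print(batch["x"].device, batch["x_lengths"].device)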
@@ -110,9 +110,10 @@ def synthesise(
     text: str,
     length_scale: float,
     temperature: float,
+    device: str = "cpu",
     spks=None,
 ) -> dict:
-    text_processed = process_text(text, tokenizer)
+    text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
     start_t = dt.datetime.now()
     output = model.synthesise(
         text_processed["x"],
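With device now a parameter of synthesise, the caller has to supply a device value. One common way to obtain it is to read it off the model's parameters; the sketch below only illustrates that pattern, and the helper name device_of is hypothetical rather than something defined in infer.py.

import torch
import torch.nn as nn

def device_of(model: nn.Module) -> torch.device:
    # Hypothetical helper: report the device the model's weights live on,
    # suitable for forwarding as synthesise(..., device=device).
    return next(model.parameters()).device

if __name__ == "__main__":
    toy = nn.Linear(4, 4)
    if torch.cuda.is_available():
        toy = toy.cuda()
    print(device_of(toy))  # e.g. cpu or cuda:0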
@@ -161,7 +162,7 @@ def infer_dataset(
     for batch_idx, batch in enumerate(dl):
         batch_size = len(batch["tokens"])

-        texts = batch["supervisions"]["text"]
+        texts = [c.supervisions[0].normalized_text for c in batch["cut"]]

         audio = batch["audio"]
         audio_lens = batch["audio_lens"].tolist()
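This hunk switches the text source from batch["supervisions"]["text"] to the first supervision's normalized_text on each cut in batch["cut"]. The stand-in classes below are only illustrative (they are not Lhotse's real cut or supervision types); they just make the attribute path used by the new list comprehension runnable on its own.

from dataclasses import dataclass
from typing import List

@dataclass
class FakeSupervision:  # stand-in for a supervision carrying normalized_text
    normalized_text: str

@dataclass
class FakeCut:  # stand-in for a cut holding a list of supervisions
    supervisions: List[FakeSupervision]

batch = {
    "cut": [
        FakeCut([FakeSupervision("hello world")]),
        FakeCut([FakeSupervision("a second normalized sentence")]),
    ]
}

# Same pattern as the new line in infer_dataset: one normalized text per cut.
texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
print(texts)  # ['hello world', 'a second normalized sentence']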
@@ -175,6 +176,7 @@ def infer_dataset(
                 text=texts[i],
                 length_scale=params.length_scale,
                 temperature=params.temperature,
+                device=device,
             )
             output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
