mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-14 04:22:21 +00:00
add tensorboard samples
This commit is contained in:
parent d874bdedac
commit 5d79953f22
 .github/scripts/ljspeech/TTS/run-matcha.sh (vendored) | 2 +-
@@ -56,7 +56,7 @@ function infer() {
 
   curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
 
-  ./matcha/inference.py \
+  ./matcha/infer.py \
     --epoch 1 \
     --exp-dir ./matcha/exp \
     --tokens data/tokens.txt \

@@ -7,9 +7,10 @@ import json
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -343,7 +344,7 @@ def compute_validation_loss(
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
     rank: int = 0,
-) -> MetricsTracker:
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
     """Run the validation process."""
     model.eval()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
@@ -391,6 +392,23 @@ def compute_validation_loss(
             # summary stats
             tot_loss = tot_loss + loss_info
 
+            # infer for first batch:
+            if batch_idx == 0 and rank == 0:
+                inner_model = model.module if isinstance(model, DDP) else model
+                audio_pred, _, duration = inner_model.inference(
+                    text=tokens[0, : tokens_lens[0].item()]
+                )
+                audio_pred = audio_pred.data.cpu().numpy()
+                audio_len_pred = (
+                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+                )
+                assert audio_len_pred == len(audio_pred), (
+                    audio_len_pred,
+                    len(audio_pred),
+                )
+                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
+                returned_sample = (audio_pred, audio_gt)
+
     if world_size > 1:
         tot_loss.reduce(device)
 
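The assert in the block above checks that the predicted total duration, converted from mel frames to waveform samples, matches the length of the audio the model synthesized. Below is a minimal sketch of that conversion using made-up numbers; in the recipe the per-token durations come from inner_model.inference and the hop size from params.frame_shift.

    import torch

    # Hypothetical per-token frame counts (in the recipe these are predicted
    # by inner_model.inference(...)).
    duration = torch.tensor([50.0, 70.0, 80.0])

    # Assumed hop size in samples per frame (params.frame_shift in the recipe).
    frame_shift = 256

    # Same arithmetic as in the diff: total frames times samples per frame.
    audio_len_pred = (duration.sum(0) * frame_shift).to(dtype=torch.int64).item()
    print(audio_len_pred)  # 51200 -> must equal len(audio_pred)
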
@@ -399,7 +417,7 @@ def compute_validation_loss(
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
 
-    return tot_loss
+    return tot_loss, returned_sample
 
 
 def train_one_epoch(
@@ -559,7 +577,7 @@ def train_one_epoch(
 
         if params.batch_idx_train % params.valid_interval == 1:
             logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
+            valid_info, (speech_hat, speech) = compute_validation_loss(
                 params=params,
                 model=model,
                 tokenizer=tokenizer,
@@ -577,6 +595,18 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
+                tb_writer.add_audio(
+                    "train/valid_speech_hat",
+                    speech_hat,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
+                tb_writer.add_audio(
+                    "train/valid_speech",
+                    speech,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
 
         loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
         params.train_loss = loss_value
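For context on the two tb_writer.add_audio calls added above: torch.utils.tensorboard's SummaryWriter.add_audio(tag, snd_tensor, global_step, sample_rate) expects a single waveform with values in [-1, 1] (the documented shape is (1, L); a 1-D array is also accepted), which is why the predicted and ground-truth samples are logged alongside params.sampling_rate. A self-contained sketch with a synthetic waveform; the log directory, sampling rate, and signal here are illustrative, not taken from the recipe:

    import numpy as np
    import torch
    from torch.utils.tensorboard import SummaryWriter

    # Illustrative log directory and sampling rate.
    tb_writer = SummaryWriter(log_dir="tmp/tb")
    sampling_rate = 22050

    # One second of a 440 Hz sine wave standing in for a generated sample.
    t = np.linspace(0.0, 1.0, sampling_rate, dtype=np.float32)
    waveform = 0.5 * np.sin(2 * np.pi * 440.0 * t)

    # Same tag pattern and argument order as in the diff above.
    tb_writer.add_audio(
        "train/valid_speech_hat",
        torch.from_numpy(waveform),
        0,              # global_step
        sampling_rate,  # sample_rate
    )
    tb_writer.close()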