add tensorboard samples

zr_jin 2024-11-04 12:22:17 +08:00
parent d874bdedac
commit 5d79953f22
2 changed files with 35 additions and 5 deletions

@@ -56,7 +56,7 @@ function infer() {
   curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
-  ./matcha/inference.py \
+  ./matcha/infer.py \
     --epoch 1 \
     --exp-dir ./matcha/exp \
     --tokens data/tokens.txt \

@@ -7,9 +7,10 @@ import json
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -343,7 +344,7 @@ def compute_validation_loss(
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
     rank: int = 0,
-) -> MetricsTracker:
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
     """Run the validation process."""
     model.eval()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
@@ -391,6 +392,23 @@ def compute_validation_loss(
             # summary stats
             tot_loss = tot_loss + loss_info
 
+            # infer for first batch:
+            if batch_idx == 0 and rank == 0:
+                inner_model = model.module if isinstance(model, DDP) else model
+                audio_pred, _, duration = inner_model.inference(
+                    text=tokens[0, : tokens_lens[0].item()]
+                )
+                audio_pred = audio_pred.data.cpu().numpy()
+                audio_len_pred = (
+                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+                )
+                assert audio_len_pred == len(audio_pred), (
+                    audio_len_pred,
+                    len(audio_pred),
+                )
+                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
+                returned_sample = (audio_pred, audio_gt)
+
     if world_size > 1:
         tot_loss.reduce(device)
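A minimal standalone sketch (not part of the commit) of the arithmetic behind the new assert: inference returns per-token durations in frames, so the expected waveform length is the total number of frames times the hop size in samples. The value frame_shift = 256 below is an illustrative assumption, not something read from the recipe.

import torch

frame_shift = 256                      # hop size in samples (assumed for illustration)
duration = torch.tensor([3, 5, 2, 4])  # made-up per-token durations, in frames

num_frames = int(duration.sum().item())
audio_len_pred = num_frames * frame_shift  # 14 frames * 256 = 3584 samples

# A vocoder that upsamples each frame by `frame_shift` yields exactly this many
# samples, which is what the assert above compares against len(audio_pred).
audio_pred = torch.zeros(audio_len_pred)
assert audio_len_pred == len(audio_pred)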
@@ -399,7 +417,7 @@ def compute_validation_loss(
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
 
-    return tot_loss
+    return tot_loss, returned_sample
 
 
 def train_one_epoch(
@@ -559,7 +577,7 @@ def train_one_epoch(
         if params.batch_idx_train % params.valid_interval == 1:
             logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
+            valid_info, (speech_hat, speech) = compute_validation_loss(
                 params=params,
                 model=model,
                 tokenizer=tokenizer,
@@ -577,6 +595,18 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
+                tb_writer.add_audio(
+                    "train/valid_speech_hat",
+                    speech_hat,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
+                tb_writer.add_audio(
+                    "train/valid_speech",
+                    speech,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
 
         loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
         params.train_loss = loss_value
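For reference, a self-contained sketch of the TensorBoard call the commit wires up, using torch.utils.tensorboard.SummaryWriter directly. The log directory, tag name, and the 22050 Hz sampling rate are assumptions for illustration and do not come from the diff.

import math

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("tb_demo")  # hypothetical log directory

sampling_rate = 22050  # assumed; the training code passes params.sampling_rate
t = torch.arange(sampling_rate, dtype=torch.float32) / sampling_rate
waveform = 0.5 * torch.sin(2 * math.pi * 440.0 * t)  # 1 s of a 440 Hz tone in [-1, 1]

# add_audio(tag, snd_tensor, global_step, sample_rate) expects float samples in
# [-1, 1]; a (1, L) tensor matches the documented shape.
writer.add_audio(
    "valid/speech_demo",
    waveform.unsqueeze(0),
    global_step=0,
    sample_rate=sampling_rate,
)
writer.close()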