add tensorboard samples

zr_jin 2024-11-04 12:22:17 +08:00
parent d874bdedac
commit 5d79953f22
2 changed files with 35 additions and 5 deletions

View File

@@ -56,7 +56,7 @@ function infer() {
   curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
-  ./matcha/inference.py \
+  ./matcha/infer.py \
     --epoch 1 \
     --exp-dir ./matcha/exp \
     --tokens data/tokens.txt \
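Note (not part of the commit): a minimal sanity-check sketch, assuming the generator_v1 file fetched by the curl line above sits in the working directory; it only confirms the vocoder checkpoint deserializes before ./matcha/infer.py is run.

import torch

# Load the downloaded HiFi-GAN checkpoint on CPU; a failure here means the
# download (not the inference script) is the problem.
ckpt = torch.load("generator_v1", map_location="cpu")
print(type(ckpt))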

View File

@@ -7,9 +7,10 @@ import json
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -343,7 +344,7 @@ def compute_validation_loss(
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
     rank: int = 0,
-) -> MetricsTracker:
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
     """Run the validation process."""
     model.eval()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
@@ -391,6 +392,23 @@ def compute_validation_loss(
 
             # summary stats
             tot_loss = tot_loss + loss_info
+
+            # infer for first batch:
+            if batch_idx == 0 and rank == 0:
+                inner_model = model.module if isinstance(model, DDP) else model
+                audio_pred, _, duration = inner_model.inference(
+                    text=tokens[0, : tokens_lens[0].item()]
+                )
+                audio_pred = audio_pred.data.cpu().numpy()
+                audio_len_pred = (
+                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+                )
+                assert audio_len_pred == len(audio_pred), (
+                    audio_len_pred,
+                    len(audio_pred),
+                )
+                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
+                returned_sample = (audio_pred, audio_gt)
 
     if world_size > 1:
         tot_loss.reduce(device)
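Note (not part of the commit): returned_sample is assigned only when rank == 0 processes batch 0, so the surrounding code presumably initializes it (e.g. to None) before the loop, outside the hunks shown here. The asserted length comes from converting predicted frame counts into waveform samples; a minimal sketch of that arithmetic with illustrative values (the frame shift of 256 samples is an assumption, not taken from the diff):

import torch

frame_shift = 256                    # samples per acoustic frame (illustrative)
duration = torch.tensor([3, 2, 4])   # predicted frames per input token (illustrative)

# Total predicted frames times samples-per-frame, mirroring
# (duration.sum(0) * params.frame_shift) in the diff above.
audio_len_pred = (duration.sum(0) * frame_shift).to(dtype=torch.int64).item()
print(audio_len_pred)  # 9 frames * 256 samples = 2304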
@@ -399,7 +417,7 @@ def compute_validation_loss(
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
 
-    return tot_loss
+    return tot_loss, returned_sample
 
 
 def train_one_epoch(
@@ -559,7 +577,7 @@ def train_one_epoch(
 
         if params.batch_idx_train % params.valid_interval == 1:
             logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
+            valid_info, (speech_hat, speech) = compute_validation_loss(
                 params=params,
                 model=model,
                 tokenizer=tokenizer,
@@ -577,6 +595,18 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
+                tb_writer.add_audio(
+                    "train/valid_speech_hat",
+                    speech_hat,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
+                tb_writer.add_audio(
+                    "train/valid_speech",
+                    speech,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
 
         loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
         params.train_loss = loss_value
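Note (not part of the commit): a standalone sketch of the TensorBoard call used above, assuming tb_writer is a torch.utils.tensorboard SummaryWriter. add_audio takes a tag, a waveform with values in [-1, 1] (a NumPy array or tensor), a global step, and a sample rate; here a synthetic 440 Hz tone stands in for the validation audio.

import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="tb_demo")      # illustrative log directory
sampling_rate = 22050                          # illustrative sample rate

t = np.linspace(0, 1.0, sampling_rate, dtype=np.float32)
wave = 0.5 * np.sin(2.0 * np.pi * 440.0 * t)   # one second of a 440 Hz sine

# Same call pattern as the diff: tag, waveform, global step, sample rate.
writer.add_audio("demo/sine", wave, 0, sampling_rate)
writer.close()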