From 5d79953f22ab60958a510786f2833c4d957e6432 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Mon, 4 Nov 2024 12:22:17 +0800
Subject: [PATCH] add tensorboard samples

---
 .github/scripts/ljspeech/TTS/run-matcha.sh |  2 +-
 egs/ljspeech/TTS/matcha/train.py           | 38 +++++++++++++++++++---
 2 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh
index 37e1bc320..0876cb47f 100755
--- a/.github/scripts/ljspeech/TTS/run-matcha.sh
+++ b/.github/scripts/ljspeech/TTS/run-matcha.sh
@@ -56,7 +56,7 @@ function infer() {
 
   curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
 
-  ./matcha/inference.py \
+  ./matcha/infer.py \
     --epoch 1 \
     --exp-dir ./matcha/exp \
     --tokens data/tokens.txt \
diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py
index 5e713fdfd..51d6cbc7a 100755
--- a/egs/ljspeech/TTS/matcha/train.py
+++ b/egs/ljspeech/TTS/matcha/train.py
@@ -7,9 +7,10 @@ import json
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
+import numpy as np
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -343,7 +344,7 @@ def compute_validation_loss(
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
     rank: int = 0,
-) -> MetricsTracker:
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
     """Run the validation process."""
     model.eval()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
@@ -391,6 +392,23 @@
             # summary stats
             tot_loss = tot_loss + loss_info
 
+            # infer for first batch:
+            if batch_idx == 0 and rank == 0:
+                inner_model = model.module if isinstance(model, DDP) else model
+                audio_pred, _, duration = inner_model.inference(
+                    text=tokens[0, : tokens_lens[0].item()]
+                )
+                audio_pred = audio_pred.data.cpu().numpy()
+                audio_len_pred = (
+                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+                )
+                assert audio_len_pred == len(audio_pred), (
+                    audio_len_pred,
+                    len(audio_pred),
+                )
+                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
+                returned_sample = (audio_pred, audio_gt)
+
     if world_size > 1:
         tot_loss.reduce(device)
 
@@ -399,7 +417,7 @@
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
 
-    return tot_loss
+    return tot_loss, returned_sample
 
 
 def train_one_epoch(
@@ -559,7 +577,7 @@
 
         if params.batch_idx_train % params.valid_interval == 1:
             logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
+            valid_info, (speech_hat, speech) = compute_validation_loss(
                 params=params,
                 model=model,
                 tokenizer=tokenizer,
@@ -577,6 +595,18 @@
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
+                tb_writer.add_audio(
+                    "train/valid_speech_hat",
+                    speech_hat,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
+                tb_writer.add_audio(
+                    "train/valid_speech",
+                    speech,
+                    params.batch_idx_train,
+                    params.sampling_rate,
+                )
 
     loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
     params.train_loss = loss_value
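
Note (illustrative, not part of the patch): the two tb_writer.add_audio() calls above rely on the standard positional signature of torch.utils.tensorboard.SummaryWriter.add_audio(tag, snd_tensor, global_step, sample_rate), which accepts a 1-D waveform (tensor or numpy array) with values in [-1, 1]. A minimal self-contained sketch follows; the log directory and the synthetic sine wave are invented for the demo, and 22050 Hz is only assumed as a typical LJSpeech sampling rate (the recipe itself passes params.sampling_rate).

    import numpy as np
    from torch.utils.tensorboard import SummaryWriter

    sampling_rate = 22050  # assumed; the patch uses params.sampling_rate
    t = np.arange(sampling_rate, dtype=np.float32) / sampling_rate
    waveform = 0.5 * np.sin(2 * np.pi * 440.0 * t)  # 1 s of a 440 Hz tone in [-1, 1]

    # Hypothetical log dir; TensorBoard shows the clip under its Audio tab.
    tb_writer = SummaryWriter("matcha/exp/tensorboard")
    tb_writer.add_audio("train/valid_speech_hat", waveform, 0, sampling_rate)
    tb_writer.close()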