From 729e86edb906d51c82fc1ecbf17202fa29507218 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Mon, 4 Nov 2024 14:26:51 +0800
Subject: [PATCH] Update train.py

---
 egs/ljspeech/TTS/matcha/train.py | 37 ++++---------------------------------
 1 file changed, 4 insertions(+), 33 deletions(-)

diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py
index 897abbb51..8ad307fda 100755
--- a/egs/ljspeech/TTS/matcha/train.py
+++ b/egs/ljspeech/TTS/matcha/train.py
@@ -7,7 +7,7 @@ import json
 import logging
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Union
 
 import k2
 import numpy as np
@@ -344,7 +344,7 @@ def compute_validation_loss(
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
     rank: int = 0,
-) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
+) -> MetricsTracker:
     """Run the validation process."""
     model.eval()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
@@ -392,23 +392,6 @@ def compute_validation_loss(
             # summary stats
             tot_loss = tot_loss + loss_info
 
-            # infer for first batch:
-            if batch_idx == 0 and rank == 0:
-                inner_model = model.module if isinstance(model, DDP) else model
-                audio_pred, _, duration = inner_model.inference(
-                    text=tokens[0, : tokens_lens[0].item()]
-                )
-                audio_pred = audio_pred.data.cpu().numpy()
-                audio_len_pred = (
-                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
-                )
-                assert audio_len_pred == len(audio_pred), (
-                    audio_len_pred,
-                    len(audio_pred),
-                )
-                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
-                returned_sample = (audio_pred, audio_gt)
-
     if world_size > 1:
         tot_loss.reduce(device)
 
@@ -417,7 +400,7 @@ def compute_validation_loss(
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value
 
-    return tot_loss, returned_sample
+    return tot_loss
 
 
 def train_one_epoch(
@@ -577,7 +560,7 @@ def train_one_epoch(
 
         if params.batch_idx_train % params.valid_interval == 1:
             logging.info("Computing validation loss")
-            valid_info, (speech_hat, speech) = compute_validation_loss(
+            valid_info = compute_validation_loss(
                 params=params,
                 model=model,
                 tokenizer=tokenizer,
@@ -595,18 +578,6 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
-                tb_writer.add_audio(
-                    "train/valid_speech_hat",
-                    speech_hat,
-                    params.batch_idx_train,
-                    params.sampling_rate,
-                )
-                tb_writer.add_audio(
-                    "train/valid_speech",
-                    speech,
-                    params.batch_idx_train,
-                    params.sampling_rate,
-                )
 
     loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
     params.train_loss = loss_value
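
After this patch, compute_validation_loss returns only a MetricsTracker: the call site in train_one_epoch no longer unpacks a (speech_hat, speech) sample, and no audio is written to TensorBoard during validation. A minimal sketch of the updated call pattern follows; the keyword arguments not visible in the hunks above (valid_dl, world_size, rank) are assumed to mirror the function signature shown in the diff, and tb_writer and the surrounding training loop are assumed from the existing script.

    if params.batch_idx_train % params.valid_interval == 1:
        logging.info("Computing validation loss")
        # Returns a MetricsTracker only; no audio sample is produced here anymore.
        valid_info = compute_validation_loss(
            params=params,
            model=model,
            tokenizer=tokenizer,
            valid_dl=valid_dl,
            world_size=world_size,
            rank=rank,
        )
        # compute_validation_loss calls model.eval(), so switch back to training mode.
        model.train()
        if tb_writer is not None:
            # Scalar summaries only; the add_audio calls are removed by this patch.
            valid_info.write_summary(
                tb_writer, "train/valid_", params.batch_idx_train
            )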