From 12c7a16a5a3eb2966469b3027c070bccf5a35805 Mon Sep 17 00:00:00 2001 From: JinZr Date: Fri, 6 Sep 2024 22:05:21 +0800 Subject: [PATCH] minor updates --- egs/libritts/CODEC/encodec/train.py | 42 +++++++++++++++++------------ 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 73d698008..842689155 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -58,6 +58,13 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) + parser.add_argument( + "--num-samples", + type=int, + default=3, + help="Number of samples to generate for tensorboard.", + ) + parser.add_argument( "--num-epochs", type=int, @@ -563,23 +570,24 @@ def train_one_epoch( valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) - speech_hat_i = speech_hat[0] - speech_i = speech[0] - if speech_hat_i.dim() > 1: - speech_hat_i = speech_hat_i.squeeze(0) - speech_i = speech_i.squeeze(0) - tb_writer.add_audio( - "train/valid_speech_hat", - speech_hat_i, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/valid_speech", - speech_i, - params.batch_idx_train, - params.sampling_rate, - ) + for index in range(params.num_samples): # 3 + speech_hat_i = speech_hat[index] + speech_i = speech[index] + if speech_hat_i.dim() > 1: + speech_hat_i = speech_hat_i.squeeze(0) + speech_i = speech_i.squeeze(0) + tb_writer.add_audio( + f"train/valid_speech_hat_{index}", + speech_hat_i, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + f"train/valid_speech_{index}", + speech_i, + params.batch_idx_train, + params.sampling_rate, + ) loss_value = tot_loss["generator_loss"] / tot_loss["samples"] params.train_loss = loss_value