diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 73d698008..842689155 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -58,6 +58,13 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) + parser.add_argument( + "--num-samples", + type=int, + default=3, + help="Number of samples to generate for tensorboard.", + ) + parser.add_argument( "--num-epochs", type=int, @@ -563,23 +570,24 @@ def train_one_epoch( valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) - speech_hat_i = speech_hat[0] - speech_i = speech[0] - if speech_hat_i.dim() > 1: - speech_hat_i = speech_hat_i.squeeze(0) - speech_i = speech_i.squeeze(0) - tb_writer.add_audio( - "train/valid_speech_hat", - speech_hat_i, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/valid_speech", - speech_i, - params.batch_idx_train, - params.sampling_rate, - ) + for index in range(params.num_samples): # 3 + speech_hat_i = speech_hat[index] + speech_i = speech[index] + if speech_hat_i.dim() > 1: + speech_hat_i = speech_hat_i.squeeze(0) + speech_i = speech_i.squeeze(0) + tb_writer.add_audio( + f"train/valid_speech_hat_{index}", + speech_hat_i, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + f"train/valid_speech_{index}", + speech_i, + params.batch_idx_train, + params.sampling_rate, + ) loss_value = tot_loss["generator_loss"] / tot_loss["samples"] params.train_loss = loss_value