minor updates

This commit is contained in:
JinZr 2024-09-06 22:05:21 +08:00
parent 4483c6e700
commit 12c7a16a5a

View File

@ -58,6 +58,13 @@ def get_parser():
help="Should various information be logged in tensorboard.", help="Should various information be logged in tensorboard.",
) )
parser.add_argument(
"--num-samples",
type=int,
default=3,
help="Number of samples to generate for tensorboard.",
)
parser.add_argument( parser.add_argument(
"--num-epochs", "--num-epochs",
type=int, type=int,
@ -563,23 +570,24 @@ def train_one_epoch(
valid_info.write_summary( valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train tb_writer, "train/valid_", params.batch_idx_train
) )
speech_hat_i = speech_hat[0] for index in range(params.num_samples): # 3
speech_i = speech[0] speech_hat_i = speech_hat[index]
if speech_hat_i.dim() > 1: speech_i = speech[index]
speech_hat_i = speech_hat_i.squeeze(0) if speech_hat_i.dim() > 1:
speech_i = speech_i.squeeze(0) speech_hat_i = speech_hat_i.squeeze(0)
tb_writer.add_audio( speech_i = speech_i.squeeze(0)
"train/valid_speech_hat", tb_writer.add_audio(
speech_hat_i, f"train/valid_speech_hat_{index}",
params.batch_idx_train, speech_hat_i,
params.sampling_rate, params.batch_idx_train,
) params.sampling_rate,
tb_writer.add_audio( )
"train/valid_speech", tb_writer.add_audio(
speech_i, f"train/valid_speech_{index}",
params.batch_idx_train, speech_i,
params.sampling_rate, params.batch_idx_train,
) params.sampling_rate,
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"] loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value params.train_loss = loss_value