minor updates

This commit is contained in:
JinZr 2024-09-06 22:05:21 +08:00
parent 4483c6e700
commit 12c7a16a5a

View File

@ -58,6 +58,13 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-samples",
type=int,
default=3,
help="Number of samples to generate for tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
@ -563,23 +570,24 @@ def train_one_epoch(
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
speech_hat_i = speech_hat[0]
speech_i = speech[0]
if speech_hat_i.dim() > 1:
speech_hat_i = speech_hat_i.squeeze(0)
speech_i = speech_i.squeeze(0)
tb_writer.add_audio(
"train/valid_speech_hat",
speech_hat_i,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/valid_speech",
speech_i,
params.batch_idx_train,
params.sampling_rate,
)
for index in range(params.num_samples): # 3
speech_hat_i = speech_hat[index]
speech_i = speech[index]
if speech_hat_i.dim() > 1:
speech_hat_i = speech_hat_i.squeeze(0)
speech_i = speech_i.squeeze(0)
tb_writer.add_audio(
f"train/valid_speech_hat_{index}",
speech_hat_i,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
f"train/valid_speech_{index}",
speech_i,
params.batch_idx_train,
params.sampling_rate,
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value