minor updates

This commit is contained in:
JinZr 2024-09-06 22:05:21 +08:00
parent 4483c6e700
commit 12c7a16a5a

View File

@ -58,6 +58,13 @@ def get_parser():
help="Should various information be logged in tensorboard.", help="Should various information be logged in tensorboard.",
) )
parser.add_argument(
"--num-samples",
type=int,
default=3,
help="Number of samples to generate for tensorboard.",
)
parser.add_argument( parser.add_argument(
"--num-epochs", "--num-epochs",
type=int, type=int,
@ -563,19 +570,20 @@ def train_one_epoch(
valid_info.write_summary( valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train tb_writer, "train/valid_", params.batch_idx_train
) )
speech_hat_i = speech_hat[0] for index in range(params.num_samples): # 3
speech_i = speech[0] speech_hat_i = speech_hat[index]
speech_i = speech[index]
if speech_hat_i.dim() > 1: if speech_hat_i.dim() > 1:
speech_hat_i = speech_hat_i.squeeze(0) speech_hat_i = speech_hat_i.squeeze(0)
speech_i = speech_i.squeeze(0) speech_i = speech_i.squeeze(0)
tb_writer.add_audio( tb_writer.add_audio(
"train/valid_speech_hat", f"train/valid_speech_hat_{index}",
speech_hat_i, speech_hat_i,
params.batch_idx_train, params.batch_idx_train,
params.sampling_rate, params.sampling_rate,
) )
tb_writer.add_audio( tb_writer.add_audio(
"train/valid_speech", f"train/valid_speech_{index}",
speech_i, speech_i,
params.batch_idx_train, params.batch_idx_train,
params.sampling_rate, params.sampling_rate,