mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00
minor updates
This commit is contained in:
parent
4483c6e700
commit
12c7a16a5a
@ -58,6 +58,13 @@ def get_parser():
|
|||||||
help="Should various information be logged in tensorboard.",
|
help="Should various information be logged in tensorboard.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-samples",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Number of samples to generate for tensorboard.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-epochs",
|
"--num-epochs",
|
||||||
type=int,
|
type=int,
|
||||||
@ -563,19 +570,20 @@ def train_one_epoch(
|
|||||||
valid_info.write_summary(
|
valid_info.write_summary(
|
||||||
tb_writer, "train/valid_", params.batch_idx_train
|
tb_writer, "train/valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
speech_hat_i = speech_hat[0]
|
for index in range(params.num_samples): # 3
|
||||||
speech_i = speech[0]
|
speech_hat_i = speech_hat[index]
|
||||||
|
speech_i = speech[index]
|
||||||
if speech_hat_i.dim() > 1:
|
if speech_hat_i.dim() > 1:
|
||||||
speech_hat_i = speech_hat_i.squeeze(0)
|
speech_hat_i = speech_hat_i.squeeze(0)
|
||||||
speech_i = speech_i.squeeze(0)
|
speech_i = speech_i.squeeze(0)
|
||||||
tb_writer.add_audio(
|
tb_writer.add_audio(
|
||||||
"train/valid_speech_hat",
|
f"train/valid_speech_hat_{index}",
|
||||||
speech_hat_i,
|
speech_hat_i,
|
||||||
params.batch_idx_train,
|
params.batch_idx_train,
|
||||||
params.sampling_rate,
|
params.sampling_rate,
|
||||||
)
|
)
|
||||||
tb_writer.add_audio(
|
tb_writer.add_audio(
|
||||||
"train/valid_speech",
|
f"train/valid_speech_{index}",
|
||||||
speech_i,
|
speech_i,
|
||||||
params.batch_idx_train,
|
params.batch_idx_train,
|
||||||
params.sampling_rate,
|
params.sampling_rate,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user