Update train.py

This commit is contained in:
jinzr 2023-11-30 22:47:59 +08:00
parent 615a5e8d46
commit cbf8b2d36c

View File

@ -452,7 +452,7 @@ def train_one_epoch(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sid=speakers,
sids=speakers,
forward_generator=False,
)
for k, v in stats_d.items():
@ -471,7 +471,7 @@ def train_one_epoch(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sid=speakers,
sids=speakers,
forward_generator=True,
return_sample=params.batch_idx_train % params.log_interval == 0,
)
@ -652,7 +652,7 @@ def compute_validation_loss(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sid=speakers,
sids=speakers,
forward_generator=False,
)
assert loss_d.requires_grad is False
@ -667,7 +667,7 @@ def compute_validation_loss(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sid=speakers,
sids=speakers,
forward_generator=True,
)
assert loss_g.requires_grad is False
@ -742,7 +742,7 @@ def scan_pessimistic_batches_for_oom(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sid=speakers,
sids=speakers,
forward_generator=False,
)
optimizer_d.zero_grad()
@ -756,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
sid=speakers,
sids=speakers,
forward_generator=True,
)
optimizer_g.zero_grad()