mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
Update train.py
This commit is contained in:
parent
615a5e8d46
commit
cbf8b2d36c
@ -452,7 +452,7 @@ def train_one_epoch(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sid=speakers,
|
||||
sids=speakers,
|
||||
forward_generator=False,
|
||||
)
|
||||
for k, v in stats_d.items():
|
||||
@ -471,7 +471,7 @@ def train_one_epoch(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sid=speakers,
|
||||
sids=speakers,
|
||||
forward_generator=True,
|
||||
return_sample=params.batch_idx_train % params.log_interval == 0,
|
||||
)
|
||||
@ -652,7 +652,7 @@ def compute_validation_loss(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sid=speakers,
|
||||
sids=speakers,
|
||||
forward_generator=False,
|
||||
)
|
||||
assert loss_d.requires_grad is False
|
||||
@ -667,7 +667,7 @@ def compute_validation_loss(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sid=speakers,
|
||||
sids=speakers,
|
||||
forward_generator=True,
|
||||
)
|
||||
assert loss_g.requires_grad is False
|
||||
@ -742,7 +742,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sid=speakers,
|
||||
sids=speakers,
|
||||
forward_generator=False,
|
||||
)
|
||||
optimizer_d.zero_grad()
|
||||
@ -756,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
feats_lengths=features_lens,
|
||||
speech=audio,
|
||||
speech_lengths=audio_lens,
|
||||
sid=speakers,
|
||||
sids=speakers,
|
||||
forward_generator=True,
|
||||
)
|
||||
optimizer_g.zero_grad()
|
||||
|
Loading…
x
Reference in New Issue
Block a user