mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
updated train.py
This commit is contained in:
parent
2356621059
commit
df87a0fe2c
@ -527,6 +527,7 @@ def train_one_epoch(
|
||||
+ params.lambda_feat * feature_loss
|
||||
+ params.lambda_com * commit_loss
|
||||
)
|
||||
loss_info["generator_loss"] = gen_loss
|
||||
for k, v in stats_g.items():
|
||||
if "returned_sample" not in k:
|
||||
loss_info[k] = v * batch_size
|
||||
@ -737,6 +738,7 @@ def compute_validation_loss(
|
||||
+ disc_scale_fake_adv_loss
|
||||
) * d_weight
|
||||
assert disc_loss.requires_grad is False
|
||||
loss_info["discriminator_loss"] = disc_loss
|
||||
for k, v in stats_d.items():
|
||||
loss_info[k] = v * batch_size
|
||||
|
||||
@ -778,6 +780,7 @@ def compute_validation_loss(
|
||||
+ params.lambda_com * commit_loss
|
||||
)
|
||||
assert gen_loss.requires_grad is False
|
||||
loss_info["generator_loss"] = gen_loss
|
||||
for k, v in stats_g.items():
|
||||
if "returned_sample" not in k:
|
||||
loss_info[k] = v * batch_size
|
||||
|
Loading…
x
Reference in New Issue
Block a user