mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
updated train.py
This commit is contained in:
parent
2356621059
commit
df87a0fe2c
@ -527,6 +527,7 @@ def train_one_epoch(
|
|||||||
+ params.lambda_feat * feature_loss
|
+ params.lambda_feat * feature_loss
|
||||||
+ params.lambda_com * commit_loss
|
+ params.lambda_com * commit_loss
|
||||||
)
|
)
|
||||||
|
loss_info["generator_loss"] = gen_loss
|
||||||
for k, v in stats_g.items():
|
for k, v in stats_g.items():
|
||||||
if "returned_sample" not in k:
|
if "returned_sample" not in k:
|
||||||
loss_info[k] = v * batch_size
|
loss_info[k] = v * batch_size
|
||||||
@ -737,6 +738,7 @@ def compute_validation_loss(
|
|||||||
+ disc_scale_fake_adv_loss
|
+ disc_scale_fake_adv_loss
|
||||||
) * d_weight
|
) * d_weight
|
||||||
assert disc_loss.requires_grad is False
|
assert disc_loss.requires_grad is False
|
||||||
|
loss_info["discriminator_loss"] = disc_loss
|
||||||
for k, v in stats_d.items():
|
for k, v in stats_d.items():
|
||||||
loss_info[k] = v * batch_size
|
loss_info[k] = v * batch_size
|
||||||
|
|
||||||
@ -778,6 +780,7 @@ def compute_validation_loss(
|
|||||||
+ params.lambda_com * commit_loss
|
+ params.lambda_com * commit_loss
|
||||||
)
|
)
|
||||||
assert gen_loss.requires_grad is False
|
assert gen_loss.requires_grad is False
|
||||||
|
loss_info["generator_loss"] = gen_loss
|
||||||
for k, v in stats_g.items():
|
for k, v in stats_g.items():
|
||||||
if "returned_sample" not in k:
|
if "returned_sample" not in k:
|
||||||
loss_info[k] = v * batch_size
|
loss_info[k] = v * batch_size
|
||||||
|
Loading…
x
Reference in New Issue
Block a user