updated train.py

This commit is contained in:
JinZr 2024-10-09 14:12:41 +08:00
parent 2356621059
commit df87a0fe2c

View File

@ -527,6 +527,7 @@ def train_one_epoch(
+ params.lambda_feat * feature_loss
+ params.lambda_com * commit_loss
)
loss_info["generator_loss"] = gen_loss
for k, v in stats_g.items():
if "returned_sample" not in k:
loss_info[k] = v * batch_size
@ -737,6 +738,7 @@ def compute_validation_loss(
+ disc_scale_fake_adv_loss
) * d_weight
assert disc_loss.requires_grad is False
loss_info["discriminator_loss"] = disc_loss
for k, v in stats_d.items():
loss_info[k] = v * batch_size
@ -778,6 +780,7 @@ def compute_validation_loss(
+ params.lambda_com * commit_loss
)
assert gen_loss.requires_grad is False
loss_info["generator_loss"] = gen_loss
for k, v in stats_g.items():
if "returned_sample" not in k:
loss_info[k] = v * batch_size