black formatted

This commit is contained in:
JinZr 2024-10-08 13:12:12 +08:00
parent 156af46a6e
commit 43267e3e29

View File

@ -289,9 +289,17 @@ def main():
logging.info(f"Number of parameters in decoder: {num_param_d}") logging.info(f"Number of parameters in decoder: {num_param_d}")
num_param_q = sum([p.numel() for p in quantizer.parameters()]) num_param_q = sum([p.numel() for p in quantizer.parameters()])
logging.info(f"Number of parameters in quantizer: {num_param_q}") logging.info(f"Number of parameters in quantizer: {num_param_q}")
num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()]) if multi_scale_discriminator is not None else 0 num_param_ds = (
sum([p.numel() for p in multi_scale_discriminator.parameters()])
if multi_scale_discriminator is not None
else 0
)
logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()]) if multi_period_discriminator is not None else 0 num_param_dp = (
sum([p.numel() for p in multi_period_discriminator.parameters()])
if multi_period_discriminator is not None
else 0
)
logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
num_param_dstft = sum( num_param_dstft = sum(
[p.numel() for p in multi_scale_stft_discriminator.parameters()] [p.numel() for p in multi_scale_stft_discriminator.parameters()]