Update train.py

This commit is contained in:
zr_jin 2024-10-22 21:29:46 +08:00
parent 5635ba0751
commit 442a745d8e

View File

@ -170,6 +170,12 @@ def get_parser():
help="The chunk size of the biomarker (in second).",
)
parser.add_argument(
"--n-q",
type=int,
help="The number of quantization levels.",
)
return parser
@ -300,7 +306,7 @@ def get_model(params: AttributeDict) -> nn.Module:
generator_params = {
"generator_n_filters": 32,
"dimension": 512,
"ratios": [8, 5, 4, 2],
"ratios": [8, 6, 4, 2],
"target_bandwidths": [1.5, 3, 6, 12, 24],
"bins": 1024,
}
@ -315,6 +321,9 @@ def get_model(params: AttributeDict) -> nn.Module:
"target_bw": 6,
}
logging.info(f"Generator params: {generator_params}")
logging.info(f"Discriminator params: {discriminator_params}")
logging.info(f"Inference params: {inference_params}")
params.update(generator_params)
params.update(discriminator_params)
params.update(inference_params)
@ -324,7 +333,7 @@ def get_model(params: AttributeDict) -> nn.Module:
1000
* params.target_bandwidths[-1]
// (math.ceil(params.sampling_rate / hop_length) * 10)
)
) if params.n_q is None else params.n_q
encoder = SEANetEncoder(
n_filters=params.generator_n_filters,