Update train.py

This commit is contained in:
zr_jin 2024-10-22 21:29:46 +08:00
parent 5635ba0751
commit 442a745d8e

View File

@ -170,6 +170,12 @@ def get_parser():
help="The chunk size of the biomarker (in second).", help="The chunk size of the biomarker (in second).",
) )
parser.add_argument(
"--n-q",
type=int,
help="The number of quantization levels.",
)
return parser return parser
@ -300,7 +306,7 @@ def get_model(params: AttributeDict) -> nn.Module:
generator_params = { generator_params = {
"generator_n_filters": 32, "generator_n_filters": 32,
"dimension": 512, "dimension": 512,
"ratios": [8, 5, 4, 2], "ratios": [8, 6, 4, 2],
"target_bandwidths": [1.5, 3, 6, 12, 24], "target_bandwidths": [1.5, 3, 6, 12, 24],
"bins": 1024, "bins": 1024,
} }
@ -315,6 +321,9 @@ def get_model(params: AttributeDict) -> nn.Module:
"target_bw": 6, "target_bw": 6,
} }
logging.info(f"Generator params: {generator_params}")
logging.info(f"Discriminator params: {discriminator_params}")
logging.info(f"Inference params: {inference_params}")
params.update(generator_params) params.update(generator_params)
params.update(discriminator_params) params.update(discriminator_params)
params.update(inference_params) params.update(inference_params)
@ -324,7 +333,7 @@ def get_model(params: AttributeDict) -> nn.Module:
1000 1000
* params.target_bandwidths[-1] * params.target_bandwidths[-1]
// (math.ceil(params.sampling_rate / hop_length) * 10) // (math.ceil(params.sampling_rate / hop_length) * 10)
) ) if params.n_q is None else params.n_q
encoder = SEANetEncoder( encoder = SEANetEncoder(
n_filters=params.generator_n_filters, n_filters=params.generator_n_filters,