mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Update train.py
This commit is contained in:
parent
5635ba0751
commit
442a745d8e
@ -170,6 +170,12 @@ def get_parser():
|
||||
help="The chunk size of the biomarker (in second).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n-q",
|
||||
type=int,
|
||||
help="The number of quantization levels.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -300,7 +306,7 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
generator_params = {
|
||||
"generator_n_filters": 32,
|
||||
"dimension": 512,
|
||||
"ratios": [8, 5, 4, 2],
|
||||
"ratios": [8, 6, 4, 2],
|
||||
"target_bandwidths": [1.5, 3, 6, 12, 24],
|
||||
"bins": 1024,
|
||||
}
|
||||
@ -315,6 +321,9 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
"target_bw": 6,
|
||||
}
|
||||
|
||||
logging.info(f"Generator params: {generator_params}")
|
||||
logging.info(f"Discriminator params: {discriminator_params}")
|
||||
logging.info(f"Inference params: {inference_params}")
|
||||
params.update(generator_params)
|
||||
params.update(discriminator_params)
|
||||
params.update(inference_params)
|
||||
@ -324,7 +333,7 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
1000
|
||||
* params.target_bandwidths[-1]
|
||||
// (math.ceil(params.sampling_rate / hop_length) * 10)
|
||||
)
|
||||
) if params.n_q is None else params.n_q
|
||||
|
||||
encoder = SEANetEncoder(
|
||||
n_filters=params.generator_n_filters,
|
||||
|
Loading…
x
Reference in New Issue
Block a user