Add num_features arg

This commit is contained in:
Rezakh20 2024-03-05 11:10:06 +03:30 committed by GitHub
parent 41a4648eb7
commit bc01398902
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -31,7 +31,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir conformer_ctc2/exp \ --exp-dir conformer_ctc2/exp \
--lang-dir data/lang_bpe_200 \ --lang-dir data/lang_bpe_200 \
--otc-token "<star>" \ --otc-token "<star>" \
--num-features 80 \ --num-features 768 \
--allow-bypass-arc true \ --allow-bypass-arc true \
--allow-self-loop-arc true \ --allow-self-loop-arc true \
--initial-bypass-weight -19 \ --initial-bypass-weight -19 \
@ -383,8 +383,8 @@ def get_params() -> AttributeDict:
- warm_step: The warm_step for Noam optimizer. - warm_step: The warm_step for Noam optimizer.
""" """
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
feature_dim = args.num_features
params = AttributeDict( params = AttributeDict(
{ {
"best_train_loss": float("inf"), "best_train_loss": float("inf"),
@ -397,7 +397,7 @@ def get_params() -> AttributeDict:
"valid_interval": 800, # For the 100h subset, use 800 "valid_interval": 800, # For the 100h subset, use 800
"alignment_interval": 25, "alignment_interval": 25,
# parameters for conformer # parameters for conformer
"feature_dim": feature_dim, "feature_dim": args.num_features,
"subsampling_factor": 2, "subsampling_factor": 2,
"encoder_dim": 512, "encoder_dim": 512,
"nhead": 8, "nhead": 8,