mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Update train.py
add num_features to input args
This commit is contained in:
parent
267a36e6ef
commit
648495d555
@ -31,6 +31,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--exp-dir conformer_ctc2/exp \
|
||||
--lang-dir data/lang_bpe_200 \
|
||||
--otc-token "<star>" \
|
||||
--num_features 80
|
||||
--allow-bypass-arc true \
|
||||
--allow-self-loop-arc true \
|
||||
--initial-bypass-weight -19 \
|
||||
@ -160,6 +161,14 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_features",
|
||||
type=int,
|
||||
default=768,
|
||||
help="""Number of features extracted in feature extraction stage.last dimension of feature vector.
|
||||
80 when using fbank features and 768 or 1024 whn using wave2vec""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
@ -373,6 +382,9 @@ def get_params() -> AttributeDict:
|
||||
|
||||
- warm_step: The warm_step for Noam optimizer.
|
||||
"""
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
feature_dim = args.num_features
|
||||
params = AttributeDict(
|
||||
{
|
||||
"best_train_loss": float("inf"),
|
||||
@ -385,7 +397,7 @@ def get_params() -> AttributeDict:
|
||||
"valid_interval": 800, # For the 100h subset, use 800
|
||||
"alignment_interval": 25,
|
||||
# parameters for conformer
|
||||
"feature_dim": 80, # when using fbank features for training
|
||||
"feature_dim": feature_dim,
|
||||
"subsampling_factor": 2,
|
||||
"encoder_dim": 512,
|
||||
"nhead": 8,
|
||||
|
Loading…
x
Reference in New Issue
Block a user