mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Pass model parameters from the command line.
This commit is contained in:
parent
85ac3a8000
commit
d79f5fecf7
@ -18,36 +18,36 @@
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./pruned_transducer_stateless4/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
@ -74,7 +74,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import get_params, get_transducer_model
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -124,7 +124,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless4/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
@ -197,6 +197,8 @@ def get_parser():
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ To run this file, do:
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
def test_model_1():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
@ -39,8 +39,26 @@ def test_model():
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
|
||||
def test_model_M():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = 18
|
||||
params.dim_feedforward = 1024
|
||||
params.encoder_dim = 256
|
||||
params.nhead = 4
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
model = get_transducer_model(params)
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
# test_model_1()
|
||||
test_model_M()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -83,6 +83,53 @@ LRSchedulerType = Union[
|
||||
]
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-encoder-layers",
|
||||
type=int,
|
||||
default=24,
|
||||
help="Number of conformer encoder layers..",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dim-feedforward",
|
||||
type=int,
|
||||
default=1536,
|
||||
help="Feedforward dimension of the conformer encoder layer.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nhead",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of attention heads in the conformer encoder layer.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-dim",
|
||||
type=int,
|
||||
default=384,
|
||||
help="Attention dimension in the conformer encoder layer.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-dim",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Embedding dimension in the decoder model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-dim",
|
||||
type=int,
|
||||
default=512,
|
||||
help="""Dimension used in the joiner model.
|
||||
Outputs from the encoder and decoder model are projected
|
||||
to this dimension before adding.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -156,15 +203,16 @@ def get_parser():
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate. This value should not need to be changed.",
|
||||
help="The initial learning rate. This value should not need "
|
||||
"to be changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr-batches",
|
||||
type=float,
|
||||
default=5000,
|
||||
help="""Number of steps that affects how rapidly the learning rate decreases.
|
||||
We suggest not to change this.""",
|
||||
help="""Number of steps that affects how rapidly the learning rate
|
||||
decreases. We suggest not to change this.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -262,6 +310,8 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -322,14 +372,6 @@ def get_params() -> AttributeDict:
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
"encoder_dim": 384,
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 1536,
|
||||
"num_encoder_layers": 24,
|
||||
# parameters for decoder
|
||||
"decoder_dim": 512,
|
||||
# parameters for joiner
|
||||
"joiner_dim": 512,
|
||||
# parameters for Noam
|
||||
"model_warm_step": 3000, # arg given to model, not for lrate
|
||||
"env_info": get_env_info(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user