Pass model parameters from the command line.

This commit is contained in:
Fangjun Kuang 2022-04-25 17:26:43 +08:00
parent 85ac3a8000
commit d79f5fecf7
3 changed files with 85 additions and 23 deletions

View File

@@ -18,36 +18,36 @@
"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/decode.py \
./pruned_transducer_stateless4/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless2/decode.py \
./pruned_transducer_stateless4/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless2/decode.py \
./pruned_transducer_stateless4/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless2/decode.py \
./pruned_transducer_stateless4/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
@@ -74,7 +74,7 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@@ -124,7 +124,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="pruned_transducer_stateless4/exp",
help="The experiment dir",
)
@@ -197,6 +197,8 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser

View File

@@ -26,7 +26,7 @@ To run this file, do:
from train import get_params, get_transducer_model
def test_model():
def test_model_1():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
@@ -39,8 +39,26 @@ def test_model():
print(f"Number of model parameters: {num_param}")
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
def test_model_M():
    """Build the medium (M) transducer configuration and report its size.

    The hyper-parameters follow Table 1 of the Conformer paper
    (https://arxiv.org/pdf/2005.08100.pdf). The parameter count is
    printed for manual inspection; nothing is asserted.
    """
    params = get_params()
    # Overrides that define the "M" model; applied via setattr so the
    # base params object keeps every other default.
    overrides = {
        "vocab_size": 500,
        "blank_id": 0,
        "context_size": 2,
        "num_encoder_layers": 18,
        "dim_feedforward": 1024,
        "encoder_dim": 256,
        "nhead": 4,
        "decoder_dim": 512,
        "joiner_dim": 512,
    }
    for name, value in overrides.items():
        setattr(params, name, value)
    model = get_transducer_model(params)
    num_param = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {num_param}")
def main():
    """Entry point: run the model-size smoke tests that are enabled."""
    test_model()
    # test_model_1()
    test_model_M()
if __name__ == "__main__":

View File

@@ -83,6 +83,53 @@ LRSchedulerType = Union[
]
def add_model_arguments(parser: argparse.ArgumentParser):
    """Add command-line options controlling the transducer model size.

    The defaults match the previously hard-coded configuration, so
    invocations that omit these flags keep the old behavior.

    Args:
      parser: The argument parser to extend; it is mutated in place.
    """
    parser.add_argument(
        "--num-encoder-layers",
        type=int,
        default=24,
        # Fixed doubled period in the help text.
        help="Number of conformer encoder layers.",
    )

    parser.add_argument(
        "--dim-feedforward",
        type=int,
        default=1536,
        help="Feedforward dimension of the conformer encoder layer.",
    )

    parser.add_argument(
        "--nhead",
        type=int,
        default=8,
        help="Number of attention heads in the conformer encoder layer.",
    )

    parser.add_argument(
        "--encoder-dim",
        type=int,
        default=384,
        help="Attention dimension in the conformer encoder layer.",
    )

    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=512,
        help="Embedding dimension in the decoder model.",
    )

    parser.add_argument(
        "--joiner-dim",
        type=int,
        default=512,
        help="""Dimension used in the joiner model.
        Outputs from the encoder and decoder model are projected
        to this dimension before adding.
        """,
    )
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -156,15 +203,16 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need to be changed.",
help="The initial learning rate. This value should not need "
"to be changed.",
)
parser.add_argument(
"--lr-batches",
type=float,
default=5000,
help="""Number of steps that affects how rapidly the learning rate decreases.
We suggest not to change this.""",
help="""Number of steps that affects how rapidly the learning rate
decreases. We suggest not to change this.""",
)
parser.add_argument(
@@ -262,6 +310,8 @@ def get_parser():
help="Whether to use half precision training.",
)
add_model_arguments(parser)
return parser
@@ -322,14 +372,6 @@ def get_params() -> AttributeDict:
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"encoder_dim": 384,
"nhead": 8,
"dim_feedforward": 1536,
"num_encoder_layers": 24,
# parameters for decoder
"decoder_dim": 512,
# parameters for joiner
"joiner_dim": 512,
# parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(),