Make --context-size configurable.

This commit is contained in:
Fangjun Kuang 2021-12-24 14:12:32 +08:00
parent 35d63de820
commit c57798661c
4 changed files with 34 additions and 8 deletions

View File

@ -114,6 +114,14 @@ def get_parser():
help="Used only when --decoding-method is beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -129,8 +137,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@ -379,6 +385,8 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")

View File

@ -104,6 +104,14 @@ def get_parser():
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -119,8 +127,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)

View File

@ -110,6 +110,14 @@ def get_parser():
help="Used only when --method is beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -126,8 +134,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)

View File

@ -130,6 +130,14 @@ def get_parser():
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -196,8 +204,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"context_size": 2, # tri-gram
# parameters for Noam
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),