Address code review

This commit is contained in:
Piotr Żelasko 2022-01-24 10:17:47 -05:00
parent 1d5fe8afa4
commit 565c1d8413

View File

@ -125,6 +125,15 @@ def get_parser():
""",
)
parser.add_argument(
"--num-decoder-layers",
type=int,
default=6,
help="""Number of decoder layer of transformer decoder.
Setting this to 0 will not create the decoder at all (pure CTC model)
""",
)
return parser
@ -203,7 +212,6 @@ def get_params() -> AttributeDict:
"use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
# parameters for loss
"beam_size": 10,
"reduction": "sum",
@ -599,6 +607,12 @@ def run(rank, world_size, args):
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
"for pure CTC training when using a phone-based lang dir."
)
assert params.num_decoder_layers == 0, (
"Attention decoder training does not support phone lang dirs "
"at this time due to a missing <sos/eos> symbol. "
"Set --num-decoder-layers=0 for pure CTC training when using "
"a phone-based lang dir."
)
graph_compiler = CtcTrainingGraphCompiler(
lexicon,
device=device,