From 565c1d8413fa9471eed3574eae5fa18eff31c1ac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Mon, 24 Jan 2022 10:17:47 -0500
Subject: [PATCH] Address code review

---
 egs/librispeech/ASR/conformer_ctc/train.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 13a546149..e7f0402db 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -125,6 +125,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layers of transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model)
+        """,
+    )
+
     return parser
 
 
@@ -203,7 +212,6 @@ def get_params() -> AttributeDict:
         "use_feat_batchnorm": True,
         "attention_dim": 512,
         "nhead": 8,
-        "num_decoder_layers": 6,
         # parameters for loss
         "beam_size": 10,
         "reduction": "sum",
@@ -599,6 +607,12 @@ def run(rank, world_size, args):
             "at this time due to a missing symbol. Set --att-rate=0 "
             "for pure CTC training when using a phone-based lang dir."
         )
+        assert params.num_decoder_layers == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing symbol. "
+            "Set --num-decoder-layers=0 for pure CTC training when using "
+            "a phone-based lang dir."
+        )
         graph_compiler = CtcTrainingGraphCompiler(
             lexicon,
             device=device,