diff --git a/egs/aishell/ASR/transformer_ctc/.decode.py.swp b/egs/aishell/ASR/transformer_ctc/.decode.py.swp
index d10f07cfd..685a7f1ad 100644
Binary files a/egs/aishell/ASR/transformer_ctc/.decode.py.swp and b/egs/aishell/ASR/transformer_ctc/.decode.py.swp differ
diff --git a/egs/aishell/ASR/transformer_ctc/.train.py.swp b/egs/aishell/ASR/transformer_ctc/.train.py.swp
index 4c6579080..54528f1be 100644
Binary files a/egs/aishell/ASR/transformer_ctc/.train.py.swp and b/egs/aishell/ASR/transformer_ctc/.train.py.swp differ
diff --git a/egs/aishell/ASR/transformer_ctc/decode.py b/egs/aishell/ASR/transformer_ctc/decode.py
index 06cf5e12a..6515ce88c 100755
--- a/egs/aishell/ASR/transformer_ctc/decode.py
+++ b/egs/aishell/ASR/transformer_ctc/decode.py
@@ -505,7 +505,8 @@ def main():
 
     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()
-
+
+    '''
     model = Conformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
@@ -517,6 +518,14 @@ def main():
         vgg_frontend=params.vgg_frontend,
         use_feat_batchnorm=params.use_feat_batchnorm,
     )
+    '''
+    model = Transformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
 
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
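
The decode.py hunks above disable the Conformer by wrapping it in a triple-quoted string and hard-code a Transformer in its place. A minimal sketch of a reversible alternative, assuming a hypothetical `model_type` option on `params` (not part of the recipe) and using only the constructor arguments visible in the diff; the imports assume the recipe-local `conformer.py` and `transformer.py` modules:

    # Sketch only: `params.model_type` is a hypothetical option, not part
    # of this change. Constructor arguments are those visible in the diff.
    from conformer import Conformer
    from transformer import Transformer

    if params.model_type == "transformer":
        model = Transformer(
            num_features=params.feature_dim,
            num_classes=num_classes,
            use_feat_batchnorm=params.use_feat_batchnorm,
            num_encoder_layers=params.num_encoder_layers,
            num_decoder_layers=params.num_decoder_layers,
        )
    else:
        model = Conformer(
            num_features=params.feature_dim,
            nhead=params.nhead,
            # ... remaining Conformer arguments, elided between the two
            # hunks above, stay as in the current decode.py ...
            vgg_frontend=params.vgg_frontend,
            use_feat_batchnorm=params.use_feat_batchnorm,
        )

Keeping both branches selectable at the command line would also let the checkpoint-averaging code that follows (`if params.avg == 1: load_checkpoint(...)`) be exercised against either encoder without further edits.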