diff --git a/.github/workflows/run-pretrained.yml b/.github/workflows/run-pretrained.yml index 6d162183d..3cdba1320 100644 --- a/.github/workflows/run-pretrained.yml +++ b/.github/workflows/run-pretrained.yml @@ -82,6 +82,7 @@ jobs: - name: Run CTC decoding shell: bash run: | + export PYTHONPATH=$PWD:PYTHONPATH cd egs/librispeech/ASR ./conformer_ctc/pretrained.py \ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ @@ -94,6 +95,7 @@ jobs: - name: Run HLG decoding shell: bash run: | + export PYTHONPATH=$PWD:$PYTHONPATH cd egs/librispeech/ASR ./conformer_ctc/pretrained.py \ --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ @@ -101,5 +103,3 @@ jobs: ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \ ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac - - diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 5554aaa7c..d1cdfa8bb 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -96,6 +96,26 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + return parser @@ -110,12 +130,6 @@ def get_params() -> AttributeDict: Explanation of options saved in `params`: - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - best_train_loss: Best training loss so far. It is used to select the model that has the lowest training loss. It is updated during the training. @@ -166,8 +180,6 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, @@ -638,6 +650,8 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) world_size = args.world_size assert world_size >= 1