Fix CI errors.

This commit is contained in:
Fangjun Kuang 2021-10-14 21:04:52 +08:00
parent 493c8812fd
commit 2de12b195e
2 changed files with 24 additions and 10 deletions

View File

@ -82,6 +82,7 @@ jobs:
- name: Run CTC decoding
shell: bash
run: |
export PYTHONPATH=$PWD:PYTHONPATH
cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
@ -94,6 +95,7 @@ jobs:
- name: Run HLG decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
@ -101,5 +103,3 @@ jobs:
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac

View File

@ -96,6 +96,26 @@ def get_parser():
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
return parser
@ -110,12 +130,6 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
@ -166,8 +180,6 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
@ -638,6 +650,8 @@ def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
world_size = args.world_size
assert world_size >= 1