diff --git a/egs/icmcasr/ASR/zipformer/decode.py b/egs/icmcasr/ASR/zipformer/decode.py index 3531d657f..12dd61b54 100755 --- a/egs/icmcasr/ASR/zipformer/decode.py +++ b/egs/icmcasr/ASR/zipformer/decode.py @@ -2,6 +2,7 @@ # # Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, # Zengwei Yao) +# 2023 NVIDIA Corporation (Wen Ding) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -20,9 +21,9 @@ Usage: (1) greedy search ./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ + --epoch 13 \ + --avg 10 \ + --exp-dir ./zipformer/exp_maxdur500_lr0.0225/ \ --max-duration 600 \ --decoding-method greedy_search @@ -106,7 +107,7 @@ import k2 import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import ICMCAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -729,7 +730,7 @@ def save_results( ) with open(errs_filename, "w") as f: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + f, f"{test_set_name}-{key}", results, enable_log=True, compute_CER=True ) test_set_wers[key] = wer @@ -755,7 +756,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + ICMCAsrDataModule.add_arguments(parser) LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -1014,16 +1015,16 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) + icmc = ICMCAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + test_ihm_cuts = icmc.dev_ihm_cuts() + test_shm_cuts = icmc.dev_shm_cuts() - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) + test_ihm_dl = icmc.test_dataloaders(test_ihm_cuts) + test_shm_dl = icmc.test_dataloaders(test_shm_cuts) - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets = ["test-ihm", "test-shm"] + test_dl = [test_ihm_dl, test_shm_dl] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset(