From 5400f4315df67ddf63f5e563f8ac2cdd1007c6ea Mon Sep 17 00:00:00 2001 From: Bailey Hirota Date: Mon, 11 Aug 2025 15:37:49 +0900 Subject: [PATCH] training and decoding compatibility changes --- egs/mls_english/ASR/local/utils/asr_datamodule.py | 6 +++--- egs/mls_english/ASR/zipformer/decode.py | 4 ++-- egs/mls_english/ASR/zipformer/train.py | 3 ++- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/egs/mls_english/ASR/local/utils/asr_datamodule.py b/egs/mls_english/ASR/local/utils/asr_datamodule.py index 6c6a1dd03..f1417c54b 100644 --- a/egs/mls_english/ASR/local/utils/asr_datamodule.py +++ b/egs/mls_english/ASR/local/utils/asr_datamodule.py @@ -347,19 +347,19 @@ class MLSEnglishHFAsrDataModule: def train_cuts(self) -> CutSet: logging.info("About to get train cuts") return load_manifest_lazy( - self.args.manifest_dir / "mls_english_cuts_train.jsonl.gz" + self.args.manifest_dir / "mls_eng_cuts_train.jsonl.gz" ) @lru_cache() def valid_cuts(self) -> CutSet: logging.info("About to get dev cuts") return load_manifest_lazy( - self.args.manifest_dir / "mls_english_cuts_dev.jsonl.gz" + self.args.manifest_dir / "mls_eng_cuts_dev.jsonl.gz" ) @lru_cache() def test_cuts(self) -> List[CutSet]: logging.info("About to get test cuts") return load_manifest_lazy( - self.args.manifest_dir / "mls_english_cuts_test.jsonl.gz" + self.args.manifest_dir / "mls_eng_cuts_test.jsonl.gz" ) diff --git a/egs/mls_english/ASR/zipformer/decode.py b/egs/mls_english/ASR/zipformer/decode.py index fc8de5d64..220cdcc9d 100755 --- a/egs/mls_english/ASR/zipformer/decode.py +++ b/egs/mls_english/ASR/zipformer/decode.py @@ -1044,13 +1044,13 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True mls_english_corpus = MLSEnglishHFAsrDataModule(args) - mls_english_corpus.load_dataset(args.dataset_path) # # dev_cuts = mls_english_corpus.dev_cuts() # test_cuts = mls_english_corpus.test_cuts() # dev_dl = mls_english_corpus.test_dataloader() - test_dl = mls_english_corpus.test_dataloader() + test_cuts = mls_english_corpus.test_cuts() + test_dl = mls_english_corpus.test_dataloaders(test_cuts) test_sets = ["test"] test_dls = [test_dl] diff --git a/egs/mls_english/ASR/zipformer/train.py b/egs/mls_english/ASR/zipformer/train.py index 9b101f1ce..63020abfb 100755 --- a/egs/mls_english/ASR/zipformer/train.py +++ b/egs/mls_english/ASR/zipformer/train.py @@ -1240,7 +1240,8 @@ def run(rank, world_size, args): train_dl = mls_english_corpus.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - valid_dl = mls_english_corpus.valid_dataloader() + valid_cuts = mls_english_corpus.valid_cuts() + valid_dl = mls_english_corpus.valid_dataloaders(valid_cuts) if not params.print_diagnostics: scan_pessimistic_batches_for_oom(