training and decoding compatibility changes

This commit is contained in:
Bailey Hirota 2025-08-11 15:37:49 +09:00
parent 130c2a59c3
commit 5400f4315d
3 changed files with 7 additions and 6 deletions

View File

@ -347,19 +347,19 @@ class MLSEnglishHFAsrDataModule:
def train_cuts(self) -> CutSet: def train_cuts(self) -> CutSet:
logging.info("About to get train cuts") logging.info("About to get train cuts")
return load_manifest_lazy( return load_manifest_lazy(
self.args.manifest_dir / "mls_english_cuts_train.jsonl.gz" self.args.manifest_dir / "mls_eng_cuts_train.jsonl.gz"
) )
@lru_cache() @lru_cache()
def valid_cuts(self) -> CutSet: def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts") logging.info("About to get dev cuts")
return load_manifest_lazy( return load_manifest_lazy(
self.args.manifest_dir / "mls_english_cuts_dev.jsonl.gz" self.args.manifest_dir / "mls_eng_cuts_dev.jsonl.gz"
) )
@lru_cache() @lru_cache()
def test_cuts(self) -> List[CutSet]: def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts") logging.info("About to get test cuts")
return load_manifest_lazy( return load_manifest_lazy(
self.args.manifest_dir / "mls_english_cuts_test.jsonl.gz" self.args.manifest_dir / "mls_eng_cuts_test.jsonl.gz"
) )

View File

@ -1044,13 +1044,13 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
mls_english_corpus = MLSEnglishHFAsrDataModule(args) mls_english_corpus = MLSEnglishHFAsrDataModule(args)
mls_english_corpus.load_dataset(args.dataset_path)
# # dev_cuts = mls_english_corpus.dev_cuts() # # dev_cuts = mls_english_corpus.dev_cuts()
# test_cuts = mls_english_corpus.test_cuts() # test_cuts = mls_english_corpus.test_cuts()
# dev_dl = mls_english_corpus.test_dataloader() # dev_dl = mls_english_corpus.test_dataloader()
test_dl = mls_english_corpus.test_dataloader() test_cuts = mls_english_corpus.test_cuts()
test_dl = mls_english_corpus.test_dataloaders(test_cuts)
test_sets = ["test"] test_sets = ["test"]
test_dls = [test_dl] test_dls = [test_dl]

View File

@ -1240,7 +1240,8 @@ def run(rank, world_size, args):
train_dl = mls_english_corpus.train_dataloaders( train_dl = mls_english_corpus.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict train_cuts, sampler_state_dict=sampler_state_dict
) )
valid_dl = mls_english_corpus.valid_dataloader() valid_cuts = mls_english_corpus.valid_cuts()
valid_dl = mls_english_corpus.valid_dataloaders(valid_cuts)
if not params.print_diagnostics: if not params.print_diagnostics:
scan_pessimistic_batches_for_oom( scan_pessimistic_batches_for_oom(