mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-10 22:45:27 +00:00
training and decoding compatibility changes
This commit is contained in:
parent
130c2a59c3
commit
5400f4315d
@ -347,19 +347,19 @@ class MLSEnglishHFAsrDataModule:
|
|||||||
def train_cuts(self) -> CutSet:
|
def train_cuts(self) -> CutSet:
|
||||||
logging.info("About to get train cuts")
|
logging.info("About to get train cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.manifest_dir / "mls_english_cuts_train.jsonl.gz"
|
self.args.manifest_dir / "mls_eng_cuts_train.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def valid_cuts(self) -> CutSet:
|
def valid_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev cuts")
|
logging.info("About to get dev cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.manifest_dir / "mls_english_cuts_dev.jsonl.gz"
|
self.args.manifest_dir / "mls_eng_cuts_dev.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_cuts(self) -> List[CutSet]:
|
def test_cuts(self) -> List[CutSet]:
|
||||||
logging.info("About to get test cuts")
|
logging.info("About to get test cuts")
|
||||||
return load_manifest_lazy(
|
return load_manifest_lazy(
|
||||||
self.args.manifest_dir / "mls_english_cuts_test.jsonl.gz"
|
self.args.manifest_dir / "mls_eng_cuts_test.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1044,13 +1044,13 @@ def main():
|
|||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
mls_english_corpus = MLSEnglishHFAsrDataModule(args)
|
mls_english_corpus = MLSEnglishHFAsrDataModule(args)
|
||||||
mls_english_corpus.load_dataset(args.dataset_path)
|
|
||||||
|
|
||||||
# # dev_cuts = mls_english_corpus.dev_cuts()
|
# # dev_cuts = mls_english_corpus.dev_cuts()
|
||||||
# test_cuts = mls_english_corpus.test_cuts()
|
# test_cuts = mls_english_corpus.test_cuts()
|
||||||
|
|
||||||
# dev_dl = mls_english_corpus.test_dataloader()
|
# dev_dl = mls_english_corpus.test_dataloader()
|
||||||
test_dl = mls_english_corpus.test_dataloader()
|
test_cuts = mls_english_corpus.test_cuts()
|
||||||
|
test_dl = mls_english_corpus.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
test_sets = ["test"]
|
test_sets = ["test"]
|
||||||
test_dls = [test_dl]
|
test_dls = [test_dl]
|
||||||
|
|||||||
@ -1240,7 +1240,8 @@ def run(rank, world_size, args):
|
|||||||
train_dl = mls_english_corpus.train_dataloaders(
|
train_dl = mls_english_corpus.train_dataloaders(
|
||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
valid_dl = mls_english_corpus.valid_dataloader()
|
valid_cuts = mls_english_corpus.valid_cuts()
|
||||||
|
valid_dl = mls_english_corpus.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user