Support training for aec-iva and gss data

This commit is contained in:
wd929 2023-12-27 17:00:45 +08:00
parent 77d8a15288
commit 7310489dc9
2 changed files with 44 additions and 19 deletions

View File

@ -101,6 +101,12 @@ class ICMCAsrDataModule:
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--manifest-aec-iva-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with aec iva data train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
@ -400,51 +406,67 @@ class ICMCAsrDataModule:
)
return test_dl
@lru_cache()
def train_ihm_cuts(self) -> CutSet:
logging.info("About to get train-ihm cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_ihm.jsonl.gz"
)
@lru_cache()
def train_ihm_rvb_cuts(self) -> CutSet:
logging.info("About to get train-ihm-rvb cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_ihm_rvb.jsonl.gz"
)
@lru_cache()
def train_shm_cuts(self) -> CutSet:
logging.info("About to get train-shm cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_sdm.jsonl.gz"
)
def train_gss_cuts(self) -> CutSet:
logging.info("About to get train-gss cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_train_gss.jsonl.gz"
)
def train_aec_iva_cuts(self) -> CutSet:
logging.info("About to get train-aec_iva cuts")
return load_manifest_lazy(
self.args.manifest_aec_iva_dir / "icmcasr-aec-iva_cuts_train_aec_iva.jsonl.gz"
@lru_cache()
def dev_ihm_cuts(self) -> CutSet:
logging.info("About to get dev-ihm cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz"
)
@lru_cache()
def dev_shm_cuts(self) -> CutSet:
logging.info("About to get dev-other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz"
)
def dev_aec_iva_cuts(self) -> CutSet:
logging.info("About to get aec iva dev cuts")
return load_manifest_lazy(
self.args.manifest_aec_iva_dir / "icmcasr-aec-iva_cuts_train_aec_iva.jsonl.gz"
)
def dev_gss_cuts(self) -> CutSet:
logging.info("About to get dev-gss cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_dev_gss.jsonl.gz"
)
def test_eval_track1_gss(self) -> CutSet:
logging.info("About to get eval1-gss cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_eval_track1_gss.jsonl.gz"
)
def test_eval_track1_sdm(self) -> CutSet:
logging.info("About to get eval1-sdm cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_eval_track1_sdm.jsonl.gz"
)
def test_eval_track1_aec_iva(self) -> CutSet:
logging.info("About to get aec iva dev cuts")
return load_manifest_lazy(
self.args.manifest_aec_iva_dir / "cuts.jsonl.gz"
# @lru_cache()
# def test_clean_cuts(self) -> CutSet:
# logging.info("About to get test-clean cuts")
# return load_manifest_lazy(
# self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
# )
# @lru_cache()
# def test_other_cuts(self) -> CutSet:
# logging.info("About to get test-other cuts")
# return load_manifest_lazy(
# self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
# )

View File

@ -1179,6 +1179,8 @@ def run(rank, world_size, args):
if params.full_data:
train_cuts += icmc.train_ihm_rvb_cuts()
train_cuts += icmc.train_shm_cuts()
train_cuts += icmc.train_aec_iva_cuts()
train_cuts += icmc.train_gss_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@ -1231,7 +1233,8 @@ def run(rank, world_size, args):
)
valid_cuts = icmc.dev_ihm_cuts()
# valid_cuts += librispeech.dev_other_cuts()
valid_cuts += icmc.dev_gss_cuts()
valid_cuts += icmc.dev_aec_iva_cuts()
valid_dl = icmc.valid_dataloaders(valid_cuts)
if not params.print_diagnostics: