minor updates

This commit is contained in:
jinzr 2023-10-18 15:07:12 +08:00
parent 5340633654
commit f3f0dfc52d
4 changed files with 30 additions and 1 deletion

View File

@ -77,6 +77,12 @@ def compute_fbank_ami():
prefix="ami-sdm",
suffix="jsonl.gz",
)
# Read the cached AMI MDM manifests (dev and test parts only). They are bound
# to their own name: the original code assigned this result to
# `manifests_sdm`, which overwrote the manifests from the preceding
# `prefix="ami-sdm"` read with a dev/test-only MDM set.
# NOTE(review): any later code in this function that should consume the MDM
# manifests must read them from `manifests_mdm` - confirm against the rest of
# the function, which is not visible in this hunk.
manifests_mdm = read_manifests_if_cached(
    dataset_parts=["dev", "test"],
    output_dir=src_dir,
    prefix="ami-mdm",
    suffix="jsonl.gz",
)
# For GSS we already have cuts so we read them directly.
manifests_gss = read_manifests_if_cached(
dataset_parts=["train", "dev", "test"],

View File

@ -294,7 +294,6 @@ class AmiAsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@ -399,6 +398,12 @@ class AmiAsrDataModule:
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def dev_mdm_cuts(self) -> CutSet:
    """Return the AMI MDM dev cuts with short cuts filtered out.

    The manifest is read with ``load_manifest_lazy`` from
    ``cuts_dev_mdm.jsonl.gz`` under ``self.args.manifest_dir`` and then
    filtered with ``self.remove_short_cuts``. The result is memoized by
    ``lru_cache``.
    """
    logging.info("About to get AMI MDM dev cuts")
    manifest_path = self.args.manifest_dir / "cuts_dev_mdm.jsonl.gz"
    return load_manifest_lazy(manifest_path).filter(self.remove_short_cuts)
@lru_cache()
def dev_gss_cuts(self) -> CutSet:
if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists():
@ -420,6 +425,12 @@ class AmiAsrDataModule:
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
return cs.filter(self.remove_short_cuts)
@lru_cache()
def test_mdm_cuts(self) -> CutSet:
    """Return the AMI MDM test cuts with short cuts filtered out.

    The manifest is read with ``load_manifest_lazy`` from
    ``cuts_test_mdm.jsonl.gz`` under ``self.args.manifest_dir`` and then
    filtered with ``self.remove_short_cuts``. The result is memoized by
    ``lru_cache``.
    """
    logging.info("About to get AMI MDM test cuts")
    manifest_path = self.args.manifest_dir / "cuts_test_mdm.jsonl.gz"
    return load_manifest_lazy(manifest_path).filter(self.remove_short_cuts)
@lru_cache()
def test_gss_cuts(self) -> CutSet:
if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():

View File

@ -691,6 +691,8 @@ def main():
test_ihm_cuts = ami.test_ihm_cuts()
dev_sdm_cuts = ami.dev_sdm_cuts()
test_sdm_cuts = ami.test_sdm_cuts()
dev_mdm_cuts = ami.dev_mdm_cuts()
test_mdm_cuts = ami.test_mdm_cuts()
dev_gss_cuts = ami.dev_gss_cuts()
test_gss_cuts = ami.test_gss_cuts()
@ -698,6 +700,8 @@ def main():
test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
dev_mdm_dl = ami.test_dataloaders(dev_mdm_cuts)
test_mdm_dl = ami.test_dataloaders(test_mdm_cuts)
if dev_gss_cuts is not None:
dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
if test_gss_cuts is not None:
@ -708,6 +712,8 @@ def main():
"test_ihm": (test_ihm_dl, test_ihm_cuts),
"dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
"test_sdm": (test_sdm_dl, test_sdm_cuts),
"dev_mdm": (dev_mdm_dl, dev_mdm_cuts),
"test_mdm": (test_mdm_dl, test_mdm_cuts),
}
if dev_gss_cuts is not None:
test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)

View File

@ -759,6 +759,8 @@ def main():
test_ihm_cuts = ami.test_ihm_cuts()
dev_sdm_cuts = ami.dev_sdm_cuts()
test_sdm_cuts = ami.test_sdm_cuts()
dev_mdm_cuts = ami.dev_mdm_cuts()
test_mdm_cuts = ami.test_mdm_cuts()
dev_gss_cuts = ami.dev_gss_cuts()
test_gss_cuts = ami.test_gss_cuts()
@ -766,6 +768,8 @@ def main():
test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
dev_mdm_dl = ami.test_dataloaders(dev_mdm_cuts)
test_mdm_dl = ami.test_dataloaders(test_mdm_cuts)
if dev_gss_cuts is not None:
dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
if test_gss_cuts is not None:
@ -776,6 +780,8 @@ def main():
"test_ihm": (test_ihm_dl, test_ihm_cuts),
"dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
"test_sdm": (test_sdm_dl, test_sdm_cuts),
"dev_mdm": (dev_mdm_dl, dev_mdm_cuts),
"test_mdm": (test_mdm_dl, test_mdm_cuts),
}
if dev_gss_cuts is not None:
test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)