mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
minor updates
This commit is contained in:
parent
5340633654
commit
f3f0dfc52d
@ -77,6 +77,12 @@ def compute_fbank_ami():
|
|||||||
prefix="ami-sdm",
|
prefix="ami-sdm",
|
||||||
suffix="jsonl.gz",
|
suffix="jsonl.gz",
|
||||||
)
|
)
|
||||||
|
manifests_sdm = read_manifests_if_cached(
|
||||||
|
dataset_parts=["dev", "test"],
|
||||||
|
output_dir=src_dir,
|
||||||
|
prefix="ami-mdm",
|
||||||
|
suffix="jsonl.gz",
|
||||||
|
)
|
||||||
# For GSS we already have cuts so we read them directly.
|
# For GSS we already have cuts so we read them directly.
|
||||||
manifests_gss = read_manifests_if_cached(
|
manifests_gss = read_manifests_if_cached(
|
||||||
dataset_parts=["train", "dev", "test"],
|
dataset_parts=["train", "dev", "test"],
|
||||||
|
@ -294,7 +294,6 @@ class AmiAsrDataModule:
|
|||||||
return train_dl
|
return train_dl
|
||||||
|
|
||||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||||
|
|
||||||
transforms = []
|
transforms = []
|
||||||
if self.args.concatenate_cuts:
|
if self.args.concatenate_cuts:
|
||||||
transforms = [
|
transforms = [
|
||||||
@ -399,6 +398,12 @@ class AmiAsrDataModule:
|
|||||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz")
|
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz")
|
||||||
return cs.filter(self.remove_short_cuts)
|
return cs.filter(self.remove_short_cuts)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def dev_mdm_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get AMI MDM dev cuts")
|
||||||
|
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_mdm.jsonl.gz")
|
||||||
|
return cs.filter(self.remove_short_cuts)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def dev_gss_cuts(self) -> CutSet:
|
def dev_gss_cuts(self) -> CutSet:
|
||||||
if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists():
|
if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists():
|
||||||
@ -420,6 +425,12 @@ class AmiAsrDataModule:
|
|||||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
|
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
|
||||||
return cs.filter(self.remove_short_cuts)
|
return cs.filter(self.remove_short_cuts)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_mdm_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get AMI MDM test cuts")
|
||||||
|
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_mdm.jsonl.gz")
|
||||||
|
return cs.filter(self.remove_short_cuts)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_gss_cuts(self) -> CutSet:
|
def test_gss_cuts(self) -> CutSet:
|
||||||
if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
|
if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
|
||||||
|
@ -691,6 +691,8 @@ def main():
|
|||||||
test_ihm_cuts = ami.test_ihm_cuts()
|
test_ihm_cuts = ami.test_ihm_cuts()
|
||||||
dev_sdm_cuts = ami.dev_sdm_cuts()
|
dev_sdm_cuts = ami.dev_sdm_cuts()
|
||||||
test_sdm_cuts = ami.test_sdm_cuts()
|
test_sdm_cuts = ami.test_sdm_cuts()
|
||||||
|
dev_mdm_cuts = ami.dev_mdm_cuts()
|
||||||
|
test_mdm_cuts = ami.test_mdm_cuts()
|
||||||
dev_gss_cuts = ami.dev_gss_cuts()
|
dev_gss_cuts = ami.dev_gss_cuts()
|
||||||
test_gss_cuts = ami.test_gss_cuts()
|
test_gss_cuts = ami.test_gss_cuts()
|
||||||
|
|
||||||
@ -698,6 +700,8 @@ def main():
|
|||||||
test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
|
test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
|
||||||
dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
|
dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
|
||||||
test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
|
test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
|
||||||
|
dev_mdm_dl = ami.test_dataloaders(dev_mdm_cuts)
|
||||||
|
test_mdm_dl = ami.test_dataloaders(test_mdm_cuts)
|
||||||
if dev_gss_cuts is not None:
|
if dev_gss_cuts is not None:
|
||||||
dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
|
dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
|
||||||
if test_gss_cuts is not None:
|
if test_gss_cuts is not None:
|
||||||
@ -708,6 +712,8 @@ def main():
|
|||||||
"test_ihm": (test_ihm_dl, test_ihm_cuts),
|
"test_ihm": (test_ihm_dl, test_ihm_cuts),
|
||||||
"dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
|
"dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
|
||||||
"test_sdm": (test_sdm_dl, test_sdm_cuts),
|
"test_sdm": (test_sdm_dl, test_sdm_cuts),
|
||||||
|
"dev_mdm": (dev_mdm_dl, dev_mdm_cuts),
|
||||||
|
"test_mdm": (test_mdm_dl, test_mdm_cuts),
|
||||||
}
|
}
|
||||||
if dev_gss_cuts is not None:
|
if dev_gss_cuts is not None:
|
||||||
test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
|
test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
|
||||||
|
@ -759,6 +759,8 @@ def main():
|
|||||||
test_ihm_cuts = ami.test_ihm_cuts()
|
test_ihm_cuts = ami.test_ihm_cuts()
|
||||||
dev_sdm_cuts = ami.dev_sdm_cuts()
|
dev_sdm_cuts = ami.dev_sdm_cuts()
|
||||||
test_sdm_cuts = ami.test_sdm_cuts()
|
test_sdm_cuts = ami.test_sdm_cuts()
|
||||||
|
dev_mdm_cuts = ami.dev_mdm_cuts()
|
||||||
|
test_mdm_cuts = ami.test_mdm_cuts()
|
||||||
dev_gss_cuts = ami.dev_gss_cuts()
|
dev_gss_cuts = ami.dev_gss_cuts()
|
||||||
test_gss_cuts = ami.test_gss_cuts()
|
test_gss_cuts = ami.test_gss_cuts()
|
||||||
|
|
||||||
@ -766,6 +768,8 @@ def main():
|
|||||||
test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
|
test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
|
||||||
dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
|
dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
|
||||||
test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
|
test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
|
||||||
|
dev_mdm_dl = ami.test_dataloaders(dev_mdm_cuts)
|
||||||
|
test_mdm_dl = ami.test_dataloaders(test_mdm_cuts)
|
||||||
if dev_gss_cuts is not None:
|
if dev_gss_cuts is not None:
|
||||||
dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
|
dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
|
||||||
if test_gss_cuts is not None:
|
if test_gss_cuts is not None:
|
||||||
@ -776,6 +780,8 @@ def main():
|
|||||||
"test_ihm": (test_ihm_dl, test_ihm_cuts),
|
"test_ihm": (test_ihm_dl, test_ihm_cuts),
|
||||||
"dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
|
"dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
|
||||||
"test_sdm": (test_sdm_dl, test_sdm_cuts),
|
"test_sdm": (test_sdm_dl, test_sdm_cuts),
|
||||||
|
"dev_mdm": (dev_mdm_dl, dev_mdm_cuts),
|
||||||
|
"test_mdm": (test_mdm_dl, test_mdm_cuts),
|
||||||
}
|
}
|
||||||
if dev_gss_cuts is not None:
|
if dev_gss_cuts is not None:
|
||||||
test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
|
test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user