mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
minor updates
This commit is contained in:
parent
5340633654
commit
f3f0dfc52d
@ -77,6 +77,12 @@ def compute_fbank_ami():
|
||||
prefix="ami-sdm",
|
||||
suffix="jsonl.gz",
|
||||
)
|
||||
manifests_sdm = read_manifests_if_cached(
|
||||
dataset_parts=["dev", "test"],
|
||||
output_dir=src_dir,
|
||||
prefix="ami-mdm",
|
||||
suffix="jsonl.gz",
|
||||
)
|
||||
# For GSS we already have cuts so we read them directly.
|
||||
manifests_gss = read_manifests_if_cached(
|
||||
dataset_parts=["train", "dev", "test"],
|
||||
|
@ -294,7 +294,6 @@ class AmiAsrDataModule:
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
@ -399,6 +398,12 @@ class AmiAsrDataModule:
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz")
|
||||
return cs.filter(self.remove_short_cuts)
|
||||
|
||||
@lru_cache()
|
||||
def dev_mdm_cuts(self) -> CutSet:
|
||||
logging.info("About to get AMI MDM dev cuts")
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_mdm.jsonl.gz")
|
||||
return cs.filter(self.remove_short_cuts)
|
||||
|
||||
@lru_cache()
|
||||
def dev_gss_cuts(self) -> CutSet:
|
||||
if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists():
|
||||
@ -420,6 +425,12 @@ class AmiAsrDataModule:
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz")
|
||||
return cs.filter(self.remove_short_cuts)
|
||||
|
||||
@lru_cache()
|
||||
def test_mdm_cuts(self) -> CutSet:
|
||||
logging.info("About to get AMI MDM test cuts")
|
||||
cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_mdm.jsonl.gz")
|
||||
return cs.filter(self.remove_short_cuts)
|
||||
|
||||
@lru_cache()
|
||||
def test_gss_cuts(self) -> CutSet:
|
||||
if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists():
|
||||
|
@ -691,6 +691,8 @@ def main():
|
||||
test_ihm_cuts = ami.test_ihm_cuts()
|
||||
dev_sdm_cuts = ami.dev_sdm_cuts()
|
||||
test_sdm_cuts = ami.test_sdm_cuts()
|
||||
dev_mdm_cuts = ami.dev_mdm_cuts()
|
||||
test_mdm_cuts = ami.test_mdm_cuts()
|
||||
dev_gss_cuts = ami.dev_gss_cuts()
|
||||
test_gss_cuts = ami.test_gss_cuts()
|
||||
|
||||
@ -698,6 +700,8 @@ def main():
|
||||
test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
|
||||
dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
|
||||
test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
|
||||
dev_mdm_dl = ami.test_dataloaders(dev_mdm_cuts)
|
||||
test_mdm_dl = ami.test_dataloaders(test_mdm_cuts)
|
||||
if dev_gss_cuts is not None:
|
||||
dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
|
||||
if test_gss_cuts is not None:
|
||||
@ -708,6 +712,8 @@ def main():
|
||||
"test_ihm": (test_ihm_dl, test_ihm_cuts),
|
||||
"dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
|
||||
"test_sdm": (test_sdm_dl, test_sdm_cuts),
|
||||
"dev_mdm": (dev_mdm_dl, dev_mdm_cuts),
|
||||
"test_mdm": (test_mdm_dl, test_mdm_cuts),
|
||||
}
|
||||
if dev_gss_cuts is not None:
|
||||
test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
|
||||
|
@ -759,6 +759,8 @@ def main():
|
||||
test_ihm_cuts = ami.test_ihm_cuts()
|
||||
dev_sdm_cuts = ami.dev_sdm_cuts()
|
||||
test_sdm_cuts = ami.test_sdm_cuts()
|
||||
dev_mdm_cuts = ami.dev_mdm_cuts()
|
||||
test_mdm_cuts = ami.test_mdm_cuts()
|
||||
dev_gss_cuts = ami.dev_gss_cuts()
|
||||
test_gss_cuts = ami.test_gss_cuts()
|
||||
|
||||
@ -766,6 +768,8 @@ def main():
|
||||
test_ihm_dl = ami.test_dataloaders(test_ihm_cuts)
|
||||
dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts)
|
||||
test_sdm_dl = ami.test_dataloaders(test_sdm_cuts)
|
||||
dev_mdm_dl = ami.test_dataloaders(dev_mdm_cuts)
|
||||
test_mdm_dl = ami.test_dataloaders(test_mdm_cuts)
|
||||
if dev_gss_cuts is not None:
|
||||
dev_gss_dl = ami.test_dataloaders(dev_gss_cuts)
|
||||
if test_gss_cuts is not None:
|
||||
@ -776,6 +780,8 @@ def main():
|
||||
"test_ihm": (test_ihm_dl, test_ihm_cuts),
|
||||
"dev_sdm": (dev_sdm_dl, dev_sdm_cuts),
|
||||
"test_sdm": (test_sdm_dl, test_sdm_cuts),
|
||||
"dev_mdm": (dev_mdm_dl, dev_mdm_cuts),
|
||||
"test_mdm": (test_mdm_dl, test_mdm_cuts),
|
||||
}
|
||||
if dev_gss_cuts is not None:
|
||||
test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)
|
||||
|
Loading…
x
Reference in New Issue
Block a user