From f3f0dfc52dca4fe5e822a0f9e3ca471aa9f0d2c4 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Wed, 18 Oct 2023 15:07:12 +0800 Subject: [PATCH] minor updates --- egs/ami/ASR/local/compute_fbank_ami.py | 6 ++++++ .../pruned_transducer_stateless7/asr_datamodule.py | 13 ++++++++++++- egs/ami/ASR/pruned_transducer_stateless7/decode.py | 6 ++++++ egs/ami/ASR/zipformer/decode.py | 6 ++++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/egs/ami/ASR/local/compute_fbank_ami.py b/egs/ami/ASR/local/compute_fbank_ami.py index 4892b40e3..c4a2541ea 100755 --- a/egs/ami/ASR/local/compute_fbank_ami.py +++ b/egs/ami/ASR/local/compute_fbank_ami.py @@ -77,6 +77,12 @@ def compute_fbank_ami(): prefix="ami-sdm", suffix="jsonl.gz", ) + manifests_sdm = read_manifests_if_cached( + dataset_parts=["dev", "test"], + output_dir=src_dir, + prefix="ami-mdm", + suffix="jsonl.gz", + ) # For GSS we already have cuts so we read them directly. manifests_gss = read_manifests_if_cached( dataset_parts=["train", "dev", "test"], diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py index 79474f1d8..592c1aaf9 100644 --- a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -294,7 +294,6 @@ class AmiAsrDataModule: return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] if self.args.concatenate_cuts: transforms = [ @@ -399,6 +398,12 @@ class AmiAsrDataModule: cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz") return cs.filter(self.remove_short_cuts) + @lru_cache() + def dev_mdm_cuts(self) -> CutSet: + logging.info("About to get AMI MDM dev cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_mdm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + @lru_cache() def dev_gss_cuts(self) -> CutSet: if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists(): @@ -420,6 +425,12 @@ class AmiAsrDataModule: cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz") return cs.filter(self.remove_short_cuts) + @lru_cache() + def test_mdm_cuts(self) -> CutSet: + logging.info("About to get AMI MDM test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_mdm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + @lru_cache() def test_gss_cuts(self) -> CutSet: if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists(): diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py index 9999894d1..e205004d6 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py @@ -691,6 +691,8 @@ def main(): test_ihm_cuts = ami.test_ihm_cuts() dev_sdm_cuts = ami.dev_sdm_cuts() test_sdm_cuts = ami.test_sdm_cuts() + dev_mdm_cuts = ami.dev_mdm_cuts() + test_mdm_cuts = ami.test_mdm_cuts() dev_gss_cuts = ami.dev_gss_cuts() test_gss_cuts = ami.test_gss_cuts() @@ -698,6 +700,8 @@ def main(): test_ihm_dl = ami.test_dataloaders(test_ihm_cuts) dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts) test_sdm_dl = ami.test_dataloaders(test_sdm_cuts) + dev_mdm_dl = ami.test_dataloaders(dev_mdm_cuts) + test_mdm_dl = ami.test_dataloaders(test_mdm_cuts) if dev_gss_cuts is not None: dev_gss_dl = ami.test_dataloaders(dev_gss_cuts) if test_gss_cuts is not None: @@ -708,6 +712,8 @@ def main(): "test_ihm": (test_ihm_dl, test_ihm_cuts), "dev_sdm": (dev_sdm_dl, dev_sdm_cuts), "test_sdm": (test_sdm_dl, test_sdm_cuts), + "dev_mdm": (dev_mdm_dl, dev_mdm_cuts), + "test_mdm": (test_mdm_dl, test_mdm_cuts), } if dev_gss_cuts is not None: test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts) diff --git a/egs/ami/ASR/zipformer/decode.py b/egs/ami/ASR/zipformer/decode.py index be36d3609..5a5bf16eb 100755 --- a/egs/ami/ASR/zipformer/decode.py +++ b/egs/ami/ASR/zipformer/decode.py @@ -759,6 +759,8 @@ def main(): test_ihm_cuts = ami.test_ihm_cuts() dev_sdm_cuts = ami.dev_sdm_cuts() test_sdm_cuts = ami.test_sdm_cuts() + dev_mdm_cuts = ami.dev_mdm_cuts() + test_mdm_cuts = ami.test_mdm_cuts() dev_gss_cuts = ami.dev_gss_cuts() test_gss_cuts = ami.test_gss_cuts() @@ -766,6 +768,8 @@ def main(): test_ihm_dl = ami.test_dataloaders(test_ihm_cuts) dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts) test_sdm_dl = ami.test_dataloaders(test_sdm_cuts) + dev_mdm_dl = ami.test_dataloaders(dev_mdm_cuts) + test_mdm_dl = ami.test_dataloaders(test_mdm_cuts) if dev_gss_cuts is not None: dev_gss_dl = ami.test_dataloaders(dev_gss_cuts) if test_gss_cuts is not None: @@ -776,6 +780,8 @@ def main(): "test_ihm": (test_ihm_dl, test_ihm_cuts), "dev_sdm": (dev_sdm_dl, dev_sdm_cuts), "test_sdm": (test_sdm_dl, test_sdm_cuts), + "dev_mdm": (dev_mdm_dl, dev_mdm_cuts), + "test_mdm": (test_mdm_dl, test_mdm_cuts), } if dev_gss_cuts is not None: test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts)