Update the ASR datamodule

This commit is contained in:
marcoyang 2024-02-27 18:02:16 +08:00
parent ab76630e0d
commit a9edd7cc3d

View File

@ -84,19 +84,12 @@ class MLSAsrDataModule:
)
group.add_argument(
"--language",
type=str2bool,
type=str,
default="all",
choices=["english", "german", "dutch", "french", "spanish", "italian", "portuguese", "polish", "all"],
help="""If all, use all the languages, other
use 960h LibriSpeech. Otherwise, use 100h subset.""",
# choices=["english", "german", "dutch", "french", "spanish", "italian", "portuguese", "polish", "all"],
help="""A list of languages separated by comma. If all, use all
the languages""",
)
group.add_argument(
"--mini-libri",
type=str2bool,
default=False,
help="True for mini librispeech",
)
group.add_argument(
"--manifest-dir",
type=Path,
@ -405,89 +398,37 @@ class MLSAsrDataModule:
return test_dl
@lru_cache()
def train_clean_5_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get train-clean-5 cuts")
def train_mls_cuts(self) -> CutSet:
    """Return the MLS training cuts for the languages chosen via ``--language``.

    ``self.args.language`` is either ``"all"`` or a comma-separated list of
    language names. Every selected language is read with
    ``load_manifest_lazy`` from
    ``<manifest_dir>/mls-<language>_train.jsonl.gz``.

    Returns:
      The CutSet of the single selected language, or, when several
      languages are selected, ``CutSet.mux`` over all of them, weighted by
      the length of each manifest and called with ``stop_early=True``.

    Raises:
      ValueError: if ``--language`` does not contain any language name.
    """
    if self.args.language == "all":
        languages = [
            "english",
            "german",
            "dutch",
            "french",
            "spanish",
            "italian",
            "portuguese",
            "polish",
        ]
    else:
        # Tolerate blanks around the commas and stray commas, so that
        # "english, german" does not look up "mls- german_train.jsonl.gz".
        languages = [
            name.strip()
            for name in self.args.language.split(",")
            if name.strip()
        ]

    if not languages:
        raise ValueError(
            f"--language must name at least one language, "
            f"got {self.args.language!r}"
        )

    if len(languages) == 1:
        language = languages[0]
        logging.info(f"About to get {language} cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / f"mls-{language}_train.jsonl.gz"
        )

    all_cuts = []
    all_cuts_len = []
    for language in languages:
        logging.info(f"About to get {language} cuts")
        current_cuts = load_manifest_lazy(
            self.args.manifest_dir / f"mls-{language}_train.jsonl.gz"
        )
        all_cuts.append(current_cuts)
        # NOTE(review): len() on a lazily opened manifest may require a pass
        # over the file -- confirm the cost is acceptable for large manifests.
        all_cuts_len.append(len(current_cuts))

    # Sample from each language in proportion to its number of cuts.
    return CutSet.mux(
        *all_cuts,
        weights=all_cuts_len,
        stop_early=True,
    )
@lru_cache()
def mls_dev_cuts(self, language: str) -> CutSet:
    """Return the MLS dev cuts of one language.

    Args:
      language:
        Language name used to build the manifest file name.

    Returns:
      The CutSet read with ``load_manifest_lazy`` from
      ``<manifest_dir>/mls-<language>_dev.jsonl.gz``.
    """
    logging.info(f"About to get dev cuts for {language}")
    # The manifest must follow the requested language; a fixed LibriSpeech
    # manifest here would silently ignore the ``language`` argument.
    return load_manifest_lazy(
        self.args.manifest_dir / f"mls-{language}_dev.jsonl.gz"
    )
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
    """Load the train-clean-100 cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``librispeech_cuts_train-clean-100.jsonl.gz``
      inside ``self.args.manifest_dir``.
    """
    logging.info("About to get train-clean-100 cuts")
    manifest_name = "librispeech_cuts_train-clean-100.jsonl.gz"
    manifest_path = self.args.manifest_dir / manifest_name
    return load_manifest_lazy(manifest_path)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
    """Load the train-clean-360 cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``librispeech_cuts_train-clean-360.jsonl.gz``
      inside ``self.args.manifest_dir``.
    """
    logging.info("About to get train-clean-360 cuts")
    manifest_name = "librispeech_cuts_train-clean-360.jsonl.gz"
    manifest_path = self.args.manifest_dir / manifest_name
    return load_manifest_lazy(manifest_path)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
    """Load the train-other-500 cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``librispeech_cuts_train-other-500.jsonl.gz``
      inside ``self.args.manifest_dir``.
    """
    logging.info("About to get train-other-500 cuts")
    manifest_name = "librispeech_cuts_train-other-500.jsonl.gz"
    manifest_path = self.args.manifest_dir / manifest_name
    return load_manifest_lazy(manifest_path)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
    """Load the shuffled, combined training cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``librispeech_cuts_train-all-shuf.jsonl.gz``
      inside ``self.args.manifest_dir``.
    """
    # Adjacent literals are concatenated into the same single-line message.
    logging.info(
        "About to get the shuffled train-clean-100, "
        "train-clean-360 and train-other-500 cuts"
    )
    manifest_name = "librispeech_cuts_train-all-shuf.jsonl.gz"
    manifest_path = self.args.manifest_dir / manifest_name
    return load_manifest_lazy(manifest_path)
@lru_cache()
def dev_clean_2_cuts(self) -> CutSet:
    """Load the mini_librispeech dev-clean-2 cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``librispeech_cuts_dev-clean-2.jsonl.gz``
      inside ``self.args.manifest_dir``.
    """
    logging.info("mini_librispeech: About to get dev-clean-2 cuts")
    manifest_name = "librispeech_cuts_dev-clean-2.jsonl.gz"
    manifest_path = self.args.manifest_dir / manifest_name
    return load_manifest_lazy(manifest_path)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
    """Load the dev-clean cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``librispeech_cuts_dev-clean.jsonl.gz``
      inside ``self.args.manifest_dir``.
    """
    logging.info("About to get dev-clean cuts")
    manifest_name = "librispeech_cuts_dev-clean.jsonl.gz"
    manifest_path = self.args.manifest_dir / manifest_name
    return load_manifest_lazy(manifest_path)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
    """Load the dev-other cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``librispeech_cuts_dev-other.jsonl.gz``
      inside ``self.args.manifest_dir``.
    """
    logging.info("About to get dev-other cuts")
    manifest_name = "librispeech_cuts_dev-other.jsonl.gz"
    manifest_path = self.args.manifest_dir / manifest_name
    return load_manifest_lazy(manifest_path)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
    """Load the test-clean cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``librispeech_cuts_test-clean.jsonl.gz``
      inside ``self.args.manifest_dir``.
    """
    logging.info("About to get test-clean cuts")
    manifest_name = "librispeech_cuts_test-clean.jsonl.gz"
    manifest_path = self.args.manifest_dir / manifest_name
    return load_manifest_lazy(manifest_path)
@lru_cache()
def test_other_cuts(self) -> CutSet:
    """Load the test-other cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``librispeech_cuts_test-other.jsonl.gz``
      inside ``self.args.manifest_dir``.
    """
    logging.info("About to get test-other cuts")
    manifest_name = "librispeech_cuts_test-other.jsonl.gz"
    manifest_path = self.args.manifest_dir / manifest_name
    return load_manifest_lazy(manifest_path)
@lru_cache()
def gigaspeech_subset_small_cuts(self) -> CutSet:
    """Load the Gigaspeech subset-S cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``cuts_S.jsonl.gz`` inside
      ``self.args.manifest_dir``.
    """
    logging.info("About to get Gigaspeech subset-S cuts")
    manifest_path = self.args.manifest_dir / "cuts_S.jsonl.gz"
    return load_manifest_lazy(manifest_path)
@lru_cache()
def gigaspeech_dev_cuts(self) -> CutSet:
    """Load the Gigaspeech dev cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``cuts_DEV.jsonl.gz`` inside
      ``self.args.manifest_dir``.
    """
    logging.info("About to get Gigaspeech dev cuts")
    manifest_path = self.args.manifest_dir / "cuts_DEV.jsonl.gz"
    return load_manifest_lazy(manifest_path)
@lru_cache()
def gigaspeech_test_cuts(self) -> CutSet:
    """Load the Gigaspeech test cuts with ``load_manifest_lazy``.

    Returns:
      The CutSet read from ``cuts_TEST.jsonl.gz`` inside
      ``self.args.manifest_dir``.
    """
    logging.info("About to get Gigaspeech test cuts")
    manifest_path = self.args.manifest_dir / "cuts_TEST.jsonl.gz"
    return load_manifest_lazy(manifest_path)
self.args.manifest_dir / f"mls-{language}_dev.jsonl.gz"
)