Update multi_dataset.py

This commit is contained in:
jinzr 2023-09-21 15:19:55 +08:00
parent 7eb2ba7d0d
commit 8a88678f91

View File

@ -15,13 +15,11 @@
# limitations under the License.
import glob
import logging
import re
from functools import lru_cache
from pathlib import Path
from typing import Dict, List
from typing import Dict
import lhotse
from lhotse import CutSet, load_manifest_lazy
@ -44,14 +42,25 @@ class MultiDataset:
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
)
# LibriSpeech
train_clean_100_cuts = self.train_clean_100_cuts()
train_clean_360_cuts = self.train_clean_360_cuts()
train_other_500_cuts = self.train_other_500_cuts()
return CutSet.mux(
aishell_2_cuts,
train_clean_100_cuts,
train_clean_360_cuts,
train_other_500_cuts,
weights=[
len(aishell_2_cuts),
len(train_clean_100_cuts),
len(train_clean_360_cuts),
len(train_other_500_cuts),
],
)
def dev_cuts(self) -> List[CutSet]:
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")
# AISHELL-2
@ -60,9 +69,20 @@ class MultiDataset:
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)
return [
# LibriSpeech
dev_clean_cuts = self.dev_clean_cuts()
dev_other_cuts = self.dev_other_cuts()
return CutSet.mux(
aishell2_dev_cuts,
]
dev_clean_cuts,
dev_other_cuts,
weights=[
len(aishell2_dev_cuts),
len(dev_clean_cuts),
len(dev_other_cuts),
],
)
def test_cuts(self) -> Dict[str, CutSet]:
logging.info("About to get multidataset test cuts")
@ -76,7 +96,62 @@ class MultiDataset:
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)
# LibriSpeech
test_clean_cuts = self.test_clean_cuts()
test_other_cuts = self.test_other_cuts()
return {
"aishell-2_test": aishell2_test_cuts,
"aishell-2_dev": aishell2_dev_cuts,
"librispeech_test_clean": test_clean_cuts,
"librispeech_test_other": test_other_cuts,
}
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
    """Load the LibriSpeech train-clean-100 cut manifest.

    The manifest is opened with ``load_manifest_lazy`` from
    ``self.args.manifest_dir``. ``lru_cache`` memoizes the result per
    ``self``, so repeated calls on the same object reuse one ``CutSet``.
    """
    # NOTE(review): the AISHELL-2 manifests in this class are resolved against
    # ``self.fbank_dir`` -- confirm that ``self.args`` is actually set here.
    logging.info("About to get train-clean-100 cuts")
    manifest_name = "librispeech_cuts_train-clean-100.jsonl.gz"
    return load_manifest_lazy(self.args.manifest_dir / manifest_name)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
    """Load the LibriSpeech train-clean-360 cut manifest.

    The manifest is opened with ``load_manifest_lazy`` from
    ``self.args.manifest_dir``. ``lru_cache`` memoizes the result per
    ``self``, so repeated calls on the same object reuse one ``CutSet``.
    """
    # NOTE(review): the AISHELL-2 manifests in this class are resolved against
    # ``self.fbank_dir`` -- confirm that ``self.args`` is actually set here.
    logging.info("About to get train-clean-360 cuts")
    manifest_name = "librispeech_cuts_train-clean-360.jsonl.gz"
    return load_manifest_lazy(self.args.manifest_dir / manifest_name)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
    """Load the LibriSpeech train-other-500 cut manifest.

    The manifest is opened with ``load_manifest_lazy`` from
    ``self.args.manifest_dir``. ``lru_cache`` memoizes the result per
    ``self``, so repeated calls on the same object reuse one ``CutSet``.
    """
    # NOTE(review): the AISHELL-2 manifests in this class are resolved against
    # ``self.fbank_dir`` -- confirm that ``self.args`` is actually set here.
    logging.info("About to get train-other-500 cuts")
    manifest_name = "librispeech_cuts_train-other-500.jsonl.gz"
    return load_manifest_lazy(self.args.manifest_dir / manifest_name)
@lru_cache()
def dev_clean_cuts(self) -> CutSet:
    """Load the LibriSpeech dev-clean cut manifest.

    The manifest is opened with ``load_manifest_lazy`` from
    ``self.args.manifest_dir``. ``lru_cache`` memoizes the result per
    ``self``, so repeated calls on the same object reuse one ``CutSet``.
    """
    # NOTE(review): the AISHELL-2 manifests in this class are resolved against
    # ``self.fbank_dir`` -- confirm that ``self.args`` is actually set here.
    logging.info("About to get dev-clean cuts")
    manifest_name = "librispeech_cuts_dev-clean.jsonl.gz"
    return load_manifest_lazy(self.args.manifest_dir / manifest_name)
@lru_cache()
def dev_other_cuts(self) -> CutSet:
    """Load the LibriSpeech dev-other cut manifest.

    The manifest is opened with ``load_manifest_lazy`` from
    ``self.args.manifest_dir``. ``lru_cache`` memoizes the result per
    ``self``, so repeated calls on the same object reuse one ``CutSet``.
    """
    # NOTE(review): the AISHELL-2 manifests in this class are resolved against
    # ``self.fbank_dir`` -- confirm that ``self.args`` is actually set here.
    logging.info("About to get dev-other cuts")
    manifest_name = "librispeech_cuts_dev-other.jsonl.gz"
    return load_manifest_lazy(self.args.manifest_dir / manifest_name)
@lru_cache()
def test_clean_cuts(self) -> CutSet:
    """Load the LibriSpeech test-clean cut manifest.

    The manifest is opened with ``load_manifest_lazy`` from
    ``self.args.manifest_dir``. ``lru_cache`` memoizes the result per
    ``self``, so repeated calls on the same object reuse one ``CutSet``.
    """
    # NOTE(review): the AISHELL-2 manifests in this class are resolved against
    # ``self.fbank_dir`` -- confirm that ``self.args`` is actually set here.
    logging.info("About to get test-clean cuts")
    manifest_name = "librispeech_cuts_test-clean.jsonl.gz"
    return load_manifest_lazy(self.args.manifest_dir / manifest_name)
@lru_cache()
def test_other_cuts(self) -> CutSet:
    """Load the LibriSpeech test-other cut manifest.

    The manifest is opened with ``load_manifest_lazy`` from
    ``self.args.manifest_dir``. ``lru_cache`` memoizes the result per
    ``self``, so repeated calls on the same object reuse one ``CutSet``.
    """
    # NOTE(review): the AISHELL-2 manifests in this class are resolved against
    # ``self.fbank_dir`` -- confirm that ``self.args`` is actually set here.
    logging.info("About to get test-other cuts")
    manifest_name = "librispeech_cuts_test-other.jsonl.gz"
    return load_manifest_lazy(self.args.manifest_dir / manifest_name)