minor updates on flags

JinZr 2023-11-22 17:24:21 +08:00
parent 4897f2c0f2
commit cd20e21552
3 changed files with 119 additions and 36 deletions

View File

@@ -311,6 +311,20 @@ def get_parser():
         help="Whether to use TAL-CSASR training data.",
     )
 
+    parser.add_argument(
+        "--use-librispeech",
+        type=str2bool,
+        default=False,
+        help="Whether to use LibriSpeech training data.",
+    )
+
+    parser.add_argument(
+        "--use-aishell2",
+        type=str2bool,
+        default=False,
+        help="Whether to use Aishell-2 training data.",
+    )
+
     add_model_arguments(parser)
 
     return parser
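For reference, here is a minimal, self-contained sketch of how str2bool-typed flags such as the two added above behave when parsed. The local str2bool helper is an illustrative stand-in for the one the training script imports; none of this is the recipe's actual code.

import argparse


def str2bool(v: str) -> bool:
    # Stand-in for the recipe's helper: argparse's plain bool() would treat
    # any non-empty string (including "false") as True.
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--use-librispeech", type=str2bool, default=False)
parser.add_argument("--use-aishell2", type=str2bool, default=False)

args = parser.parse_args(["--use-librispeech", "true"])
print(args.use_librispeech, args.use_aishell2)  # True False

Both flags default to False, so existing command lines keep their behavior unless a corpus is switched on explicitly.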

View File

@@ -34,15 +34,18 @@ class MultiDataset:
         """
         self.fbank_dir = Path(args.manifest_dir)
         self.use_tal_csasr = args.use_tal_csasr
+        self.use_librispeech = args.use_librispeech
+        self.use_aishell2 = args.use_aishell2
 
     def train_cuts(self) -> CutSet:
         logging.info("About to get multidataset train cuts")
 
         # AISHELL-2
-        logging.info("Loading Aishell-2 in lazy mode")
-        aishell_2_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
-        )
+        if self.use_aishell2:
+            logging.info("Loading Aishell-2 in lazy mode")
+            aishell_2_cuts = load_manifest_lazy(
+                self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
+            )
 
         # TAL-CSASR
         if self.use_tal_csasr:
@@ -52,11 +55,13 @@ class MultiDataset:
             )
 
         # LibriSpeech
-        train_clean_100_cuts = self.train_clean_100_cuts()
-        train_clean_360_cuts = self.train_clean_360_cuts()
-        train_other_500_cuts = self.train_other_500_cuts()
+        if self.use_librispeech:
+            logging.info("Loading LibriSpeech in lazy mode")
+            train_clean_100_cuts = self.train_clean_100_cuts()
+            train_clean_360_cuts = self.train_clean_360_cuts()
+            train_other_500_cuts = self.train_other_500_cuts()
 
-        if self.use_tal_csasr:
+        if self.use_tal_csasr and self.use_librispeech and self.use_aishell2:
             return CutSet.mux(
                 aishell_2_cuts,
                 train_clean_100_cuts,
@@ -71,18 +76,43 @@ class MultiDataset:
                     len(tal_csasr_cuts),
                 ],
             )
-        return CutSet.mux(
-            aishell_2_cuts,
-            train_clean_100_cuts,
-            train_clean_360_cuts,
-            train_other_500_cuts,
-            weights=[
-                len(aishell_2_cuts),
-                len(train_clean_100_cuts),
-                len(train_clean_360_cuts),
-                len(train_other_500_cuts),
-            ],
-        )
+        elif not self.use_tal_csasr and self.use_librispeech and self.use_aishell2:
+            return CutSet.mux(
+                aishell_2_cuts,
+                train_clean_100_cuts,
+                train_clean_360_cuts,
+                train_other_500_cuts,
+                weights=[
+                    len(aishell_2_cuts),
+                    len(train_clean_100_cuts),
+                    len(train_clean_360_cuts),
+                    len(train_other_500_cuts),
+                ],
+            )
+        elif self.use_tal_csasr and not self.use_librispeech and self.use_aishell2:
+            return CutSet.mux(
+                aishell_2_cuts,
+                tal_csasr_cuts,
+                weights=[
+                    len(aishell_2_cuts),
+                    len(tal_csasr_cuts),
+                ],
+            )
+        elif self.use_tal_csasr and self.use_librispeech and not self.use_aishell2:
+            return CutSet.mux(
+                train_clean_100_cuts,
+                train_clean_360_cuts,
+                train_other_500_cuts,
+                tal_csasr_cuts,
+                weights=[
+                    len(train_clean_100_cuts),
+                    len(train_clean_360_cuts),
+                    len(train_other_500_cuts),
+                    len(tal_csasr_cuts),
+                ],
+            )
+        else:
+            raise NotImplementedError
 
     def dev_cuts(self) -> CutSet:
         logging.info("About to get multidataset dev cuts")
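The committed train_cuts() enumerates flag combinations explicitly and raises NotImplementedError for the ones it does not cover (for example, a single enabled corpus). Below is a hedged alternative sketch, not the committed code: collect whichever cut sets the flags enabled and mux them once, with weights proportional to their sizes, as the weights= lists in the diff already do.

from typing import Dict

from lhotse import CutSet


def mux_enabled(cut_sets: Dict[str, CutSet]) -> CutSet:
    # cut_sets would be populated by the caller under the --use-* flags,
    # e.g. {"aishell2": ..., "librispeech": ..., "tal_csasr": ...}.
    if not cut_sets:
        raise ValueError("At least one training corpus must be enabled")
    sets = list(cut_sets.values())
    if len(sets) == 1:
        return sets[0]
    # Weight each corpus by its number of cuts so sampling stays roughly
    # proportional to corpus size, matching the weights= usage above.
    return CutSet.mux(*sets, weights=[len(c) for c in sets])

Whether such a refactor is worthwhile depends on how often new corpora are added; the explicit branches in the commit keep each supported combination easy to read.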
@@ -97,14 +127,21 @@ class MultiDataset:
         dev_clean_cuts = self.dev_clean_cuts()
         dev_other_cuts = self.dev_other_cuts()
 
+        logging.info("Loading TAL-CSASR set in lazy mode")
+        tal_csasr_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
+        )
+
         return CutSet.mux(
             aishell2_dev_cuts,
             dev_clean_cuts,
             dev_other_cuts,
+            tal_csasr_dev_cuts,
             weights=[
                 len(aishell2_dev_cuts),
                 len(dev_clean_cuts),
                 len(dev_other_cuts),
+                len(tal_csasr_dev_cuts),
             ],
         )
@@ -112,31 +149,49 @@ class MultiDataset:
         logging.info("About to get multidataset test cuts")
 
         # AISHELL-2
-        logging.info("Loading Aishell-2 set in lazy mode")
-        aishell2_test_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
-        )
-        aishell2_dev_cuts = load_manifest_lazy(
-            self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
-        )
+        if self.use_aishell2:
+            logging.info("Loading Aishell-2 set in lazy mode")
+            aishell2_test_cuts = load_manifest_lazy(
+                self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
+            )
+            aishell2_dev_cuts = load_manifest_lazy(
+                self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
+            )
 
         # LibriSpeech
-        test_clean_cuts = self.test_clean_cuts()
-        test_other_cuts = self.test_other_cuts()
+        if self.use_librispeech:
+            test_clean_cuts = self.test_clean_cuts()
+            test_other_cuts = self.test_other_cuts()
 
         logging.info("Loading TAL-CSASR set in lazy mode")
-        tal_csasr_cuts = load_manifest_lazy(
+        tal_csasr_test_cuts = load_manifest_lazy(
             self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz"
         )
+        tal_csasr_dev_cuts = load_manifest_lazy(
+            self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
+        )
 
-        return {
-            "aishell-2_test": aishell2_test_cuts,
-            "aishell-2_dev": aishell2_dev_cuts,
-            "librispeech_test_clean": test_clean_cuts,
-            "librispeech_test_other": test_other_cuts,
-            "tal_csasr_cuts_test": tal_csasr_cuts,
+        test_cuts = {
+            "tal_csasr_test": tal_csasr_test_cuts,
+            "tal_csasr_dev": tal_csasr_dev_cuts,
         }
+        if self.use_aishell2:
+            test_cuts.update(
+                {
+                    "aishell-2_test": aishell2_test_cuts,
+                    "aishell-2_dev": aishell2_dev_cuts,
+                }
+            )
+        if self.use_librispeech:
+            test_cuts.update(
+                {
+                    "librispeech_test_clean": test_clean_cuts,
+                    "librispeech_test_other": test_other_cuts,
+                }
+            )
+        return test_cuts
 
     @lru_cache()
     def train_clean_100_cuts(self) -> CutSet:
         logging.info("About to get train-clean-100 cuts")
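Purely as an illustration of the new test_cuts() contract (a dict keyed by evaluation-set name), the tiny standalone snippet below mirrors which keys would be present for a given flag combination; it is not part of the commit.

def expected_test_keys(use_aishell2: bool, use_librispeech: bool) -> list:
    # TAL-CSASR test and dev sets are always returned by the new logic.
    keys = ["tal_csasr_test", "tal_csasr_dev"]
    if use_aishell2:
        keys += ["aishell-2_test", "aishell-2_dev"]
    if use_librispeech:
        keys += ["librispeech_test_clean", "librispeech_test_other"]
    return keys


print(expected_test_keys(use_aishell2=True, use_librispeech=False))
# ['tal_csasr_test', 'tal_csasr_dev', 'aishell-2_test', 'aishell-2_dev']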

View File

@@ -475,6 +475,20 @@ def get_parser():
         help="Whether to use TAL-CSASR training data.",
     )
 
+    parser.add_argument(
+        "--use-librispeech",
+        type=str2bool,
+        default=False,
+        help="Whether to use LibriSpeech training data.",
+    )
+
+    parser.add_argument(
+        "--use-aishell2",
+        type=str2bool,
+        default=False,
+        help="Whether to use Aishell-2 training data.",
+    )
+
     add_model_arguments(parser)
 
     return parser