mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
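minor updates on flags (add --use-librispeech and --use-aishell2 options, both defaulting to False, so that MultiDataset loads LibriSpeech and Aishell-2 cuts only when requested)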
This commit is contained in:
parent
4897f2c0f2
commit
cd20e21552
@ -311,6 +311,20 @@ def get_parser():
|
||||
help="Whether to use TAL-CSASR training data.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-librispeech",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use LibriSpeech training data.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-aishell2",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use Aishell-2 training data.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
@ -34,15 +34,18 @@ class MultiDataset:
|
||||
"""
|
||||
self.fbank_dir = Path(args.manifest_dir)
|
||||
self.use_tal_csasr = args.use_tal_csasr
|
||||
self.use_librispeech = args.use_librispeech
|
||||
self.use_aishell2 = args.use_aishell2
|
||||
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset train cuts")
|
||||
|
||||
# AISHELL-2
|
||||
logging.info("Loading Aishell-2 in lazy mode")
|
||||
aishell_2_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
|
||||
)
|
||||
if self.use_aishell2:
|
||||
logging.info("Loading Aishell-2 in lazy mode")
|
||||
aishell_2_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
# TAL-CSASR
|
||||
if self.use_tal_csasr:
|
||||
@ -52,11 +55,13 @@ class MultiDataset:
|
||||
)
|
||||
|
||||
# LibriSpeech
|
||||
train_clean_100_cuts = self.train_clean_100_cuts()
|
||||
train_clean_360_cuts = self.train_clean_360_cuts()
|
||||
train_other_500_cuts = self.train_other_500_cuts()
|
||||
if self.use_librispeech:
|
||||
logging.info("Loading LibriSpeech in lazy mode")
|
||||
train_clean_100_cuts = self.train_clean_100_cuts()
|
||||
train_clean_360_cuts = self.train_clean_360_cuts()
|
||||
train_other_500_cuts = self.train_other_500_cuts()
|
||||
|
||||
if self.use_tal_csasr:
|
||||
if self.use_tal_csasr and self.use_librispeech and self.use_aishell2:
|
||||
return CutSet.mux(
|
||||
aishell_2_cuts,
|
||||
train_clean_100_cuts,
|
||||
@ -71,18 +76,43 @@ class MultiDataset:
|
||||
len(tal_csasr_cuts),
|
||||
],
|
||||
)
|
||||
return CutSet.mux(
|
||||
aishell_2_cuts,
|
||||
train_clean_100_cuts,
|
||||
train_clean_360_cuts,
|
||||
train_other_500_cuts,
|
||||
weights=[
|
||||
len(aishell_2_cuts),
|
||||
len(train_clean_100_cuts),
|
||||
len(train_clean_360_cuts),
|
||||
len(train_other_500_cuts),
|
||||
],
|
||||
)
|
||||
elif not self.use_tal_csasr and self.use_librispeech and self.use_aishell2:
|
||||
return CutSet.mux(
|
||||
aishell_2_cuts,
|
||||
train_clean_100_cuts,
|
||||
train_clean_360_cuts,
|
||||
train_other_500_cuts,
|
||||
weights=[
|
||||
len(aishell_2_cuts),
|
||||
len(train_clean_100_cuts),
|
||||
len(train_clean_360_cuts),
|
||||
len(train_other_500_cuts),
|
||||
],
|
||||
)
|
||||
elif self.use_tal_csasr and not self.use_librispeech and self.use_aishell2:
|
||||
return CutSet.mux(
|
||||
aishell_2_cuts,
|
||||
tal_csasr_cuts,
|
||||
weights=[
|
||||
len(aishell_2_cuts),
|
||||
len(tal_csasr_cuts),
|
||||
],
|
||||
)
|
||||
elif self.use_tal_csasr and self.use_librispeech and not self.use_aishell2:
|
||||
return CutSet.mux(
|
||||
train_clean_100_cuts,
|
||||
train_clean_360_cuts,
|
||||
train_other_500_cuts,
|
||||
tal_csasr_cuts,
|
||||
weights=[
|
||||
len(train_clean_100_cuts),
|
||||
len(train_clean_360_cuts),
|
||||
len(train_other_500_cuts),
|
||||
len(tal_csasr_cuts),
|
||||
],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset dev cuts")
|
||||
@ -97,14 +127,21 @@ class MultiDataset:
|
||||
dev_clean_cuts = self.dev_clean_cuts()
|
||||
dev_other_cuts = self.dev_other_cuts()
|
||||
|
||||
logging.info("Loading TAL-CSASR set in lazy mode")
|
||||
tal_csasr_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
|
||||
)
|
||||
|
||||
return CutSet.mux(
|
||||
aishell2_dev_cuts,
|
||||
dev_clean_cuts,
|
||||
dev_other_cuts,
|
||||
tal_csasr_dev_cuts,
|
||||
weights=[
|
||||
len(aishell2_dev_cuts),
|
||||
len(dev_clean_cuts),
|
||||
len(dev_other_cuts),
|
||||
len(tal_csasr_dev_cuts),
|
||||
],
|
||||
)
|
||||
|
||||
@ -112,31 +149,49 @@ class MultiDataset:
|
||||
logging.info("About to get multidataset test cuts")
|
||||
|
||||
# AISHELL-2
|
||||
logging.info("Loading Aishell-2 set in lazy mode")
|
||||
aishell2_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
|
||||
)
|
||||
aishell2_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
|
||||
)
|
||||
if self.use_aishell2:
|
||||
logging.info("Loading Aishell-2 set in lazy mode")
|
||||
aishell2_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_test.jsonl.gz"
|
||||
)
|
||||
aishell2_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
|
||||
)
|
||||
|
||||
# LibriSpeech
|
||||
test_clean_cuts = self.test_clean_cuts()
|
||||
test_other_cuts = self.test_other_cuts()
|
||||
if self.use_librispeech:
|
||||
test_clean_cuts = self.test_clean_cuts()
|
||||
test_other_cuts = self.test_other_cuts()
|
||||
|
||||
logging.info("Loading TAL-CSASR set in lazy mode")
|
||||
tal_csasr_cuts = load_manifest_lazy(
|
||||
tal_csasr_test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz"
|
||||
)
|
||||
tal_csasr_dev_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "tal_csasr_cuts_dev_set.jsonl.gz"
|
||||
)
|
||||
|
||||
return {
|
||||
"aishell-2_test": aishell2_test_cuts,
|
||||
"aishell-2_dev": aishell2_dev_cuts,
|
||||
"librispeech_test_clean": test_clean_cuts,
|
||||
"librispeech_test_other": test_other_cuts,
|
||||
"tal_csasr_cuts_test": tal_csasr_cuts,
|
||||
test_cuts = {
|
||||
"tal_csasr_test": tal_csasr_test_cuts,
|
||||
"tal_csasr_dev": tal_csasr_dev_cuts,
|
||||
}
|
||||
|
||||
if self.use_aishell2:
|
||||
test_cuts.update(
|
||||
{
|
||||
"aishell-2_test": aishell2_test_cuts,
|
||||
"aishell-2_dev": aishell2_dev_cuts,
|
||||
}
|
||||
)
|
||||
if self.use_librispeech:
|
||||
test_cuts.update(
|
||||
{
|
||||
"librispeech_test_clean": test_clean_cuts,
|
||||
"librispeech_test_other": test_other_cuts,
|
||||
}
|
||||
)
|
||||
return test_cuts
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
|
@ -475,6 +475,20 @@ def get_parser():
|
||||
help="Whether to use TAL-CSASR training data.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-librispeech",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use LibriSpeech training data.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-aishell2",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use Aishell-2 training data.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
Loading…
x
Reference in New Issue
Block a user