minor updates

This commit is contained in:
JinZr 2023-09-28 10:47:39 +08:00
parent 21dad4506c
commit 09ada8fb48
2 changed files with 34 additions and 3 deletions

View File

@ -15,6 +15,7 @@
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
@ -24,14 +25,15 @@ from lhotse import CutSet, load_manifest_lazy
class MultiDataset:
def __init__(self, fbank_dir: str):
def __init__(self, args: argparse.Namespace):
"""
Args:
manifest_dir:
It is expected to contain the following files:
- aishell2_cuts_train.jsonl.gz
"""
self.fbank_dir = Path(fbank_dir)
self.fbank_dir = Path(args.fbank_dir)
self.use_tal_csasr = args.use_tal_csasr
def train_cuts(self) -> CutSet:
logging.info("About to get multidataset train cuts")
@ -42,11 +44,33 @@ class MultiDataset:
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
)
# TAL-CSASR
if self.use_tal_csasr:
logging.info("Loading TAL-CSASR in lazy mode")
tal_csasr_cuts = load_manifest_lazy(
self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz"
)
# LibriSpeech
train_clean_100_cuts = self.train_clean_100_cuts()
train_clean_360_cuts = self.train_clean_360_cuts()
train_other_500_cuts = self.train_other_500_cuts()
if self.use_tal_csasr:
return CutSet.mux(
aishell_2_cuts,
train_clean_100_cuts,
train_clean_360_cuts,
train_other_500_cuts,
tal_csasr_cuts,
weights=[
len(aishell_2_cuts),
len(train_clean_100_cuts),
len(train_clean_360_cuts),
len(train_other_500_cuts),
len(tal_csasr_cuts),
],
)
return CutSet.mux(
aishell_2_cuts,
train_clean_100_cuts,
@ -99,7 +123,7 @@ class MultiDataset:
# LibriSpeech
test_clean_cuts = self.test_clean_cuts()
test_other_cuts = self.test_other_cuts()
logging.info("Loading TAL-CSASR set in lazy mode")
tal_csasr_cuts = load_manifest_lazy(
self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz"

View File

@ -468,6 +468,13 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--use-tal-csasr",
type=str2bool,
default=False,
help="Whether to use TAL-CSASR training data.",
)
add_model_arguments(parser)
return parser