mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-06 15:44:17 +00:00
minor updates
This commit is contained in:
parent
21dad4506c
commit
09ada8fb48
@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
@ -24,14 +25,15 @@ from lhotse import CutSet, load_manifest_lazy
|
||||
|
||||
|
||||
class MultiDataset:
|
||||
def __init__(self, fbank_dir: str):
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
"""
|
||||
Args:
|
||||
args:
|
||||
Parsed command-line arguments. The directory given by args.fbank_dir is expected to contain the following files:
|
||||
- aishell2_cuts_train.jsonl.gz
|
||||
"""
|
||||
self.fbank_dir = Path(fbank_dir)
|
||||
self.fbank_dir = Path(args.fbank_dir)
|
||||
self.use_tal_csasr = args.use_tal_csasr
|
||||
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get multidataset train cuts")
|
||||
@ -42,11 +44,33 @@ class MultiDataset:
|
||||
self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
# TAL-CSASR
|
||||
if self.use_tal_csasr:
|
||||
logging.info("Loading TAL-CSASR in lazy mode")
|
||||
tal_csasr_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz"
|
||||
)
|
||||
|
||||
# LibriSpeech
|
||||
train_clean_100_cuts = self.train_clean_100_cuts()
|
||||
train_clean_360_cuts = self.train_clean_360_cuts()
|
||||
train_other_500_cuts = self.train_other_500_cuts()
|
||||
|
||||
if self.use_tal_csasr:
|
||||
return CutSet.mux(
|
||||
aishell_2_cuts,
|
||||
train_clean_100_cuts,
|
||||
train_clean_360_cuts,
|
||||
train_other_500_cuts,
|
||||
tal_csasr_cuts,
|
||||
weights=[
|
||||
len(aishell_2_cuts),
|
||||
len(train_clean_100_cuts),
|
||||
len(train_clean_360_cuts),
|
||||
len(train_other_500_cuts),
|
||||
len(tal_csasr_cuts),
|
||||
],
|
||||
)
|
||||
return CutSet.mux(
|
||||
aishell_2_cuts,
|
||||
train_clean_100_cuts,
|
||||
@ -99,7 +123,7 @@ class MultiDataset:
|
||||
# LibriSpeech
|
||||
test_clean_cuts = self.test_clean_cuts()
|
||||
test_other_cuts = self.test_other_cuts()
|
||||
|
||||
|
||||
logging.info("Loading TAL-CSASR set in lazy mode")
|
||||
tal_csasr_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / "tal_csasr_cuts_test_set.jsonl.gz"
|
||||
|
@ -468,6 +468,13 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-tal-csasr",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use TAL-CSASR training data.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
Loading…
x
Reference in New Issue
Block a user