Repository: https://github.com/k2-fsa/icefall.git
Commit: 09ada8fb48 (parent: 21dad4506c)
Commit message: minor updates
@@ -15,6 +15,7 @@
 # limitations under the License.


+import argparse
 import logging
 from functools import lru_cache
 from pathlib import Path
@@ -24,14 +25,15 @@ from lhotse import CutSet, load_manifest_lazy


 class MultiDataset:
-    def __init__(self, fbank_dir: str):
+    def __init__(self, args: argparse.Namespace):
         """
         Args:
           manifest_dir:
             It is expected to contain the following files:
             - aishell2_cuts_train.jsonl.gz
         """
-        self.fbank_dir = Path(fbank_dir)
+        self.fbank_dir = Path(args.fbank_dir)
+        self.use_tal_csasr = args.use_tal_csasr

     def train_cuts(self) -> CutSet:
         logging.info("About to get multidataset train cuts")
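For orientation, this change means MultiDataset now receives the whole parsed argparse.Namespace instead of a bare fbank_dir string, so any option it needs (such as use_tal_csasr) can be read off args. Below is a minimal sketch of a caller after the change; the wrapping script and the --fbank-dir option are assumptions for illustration, not part of this commit, and the real recipe uses icefall's str2bool for the boolean flag.

# Sketch only: constructing MultiDataset after this change.
import argparse

def _str2bool(v: str) -> bool:
    # stand-in for icefall's str2bool helper (assumed behaviour)
    return v.lower() in ("yes", "true", "t", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--fbank-dir", type=str, default="data/fbank")   # assumed option/default
parser.add_argument("--use-tal-csasr", type=_str2bool, default=False)
args = parser.parse_args(["--use-tal-csasr", "true"])

# MultiDataset as defined in the diff above (import path omitted here).
multi_dataset = MultiDataset(args)   # previously: MultiDataset(args.fbank_dir)
train_cuts = multi_dataset.train_cuts()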
@@ -42,11 +44,33 @@ class MultiDataset:
             self.fbank_dir / "aishell2_cuts_train.jsonl.gz"
         )

+        # TAL-CSASR
+        if self.use_tal_csasr:
+            logging.info("Loading TAL-CSASR in lazy mode")
+            tal_csasr_cuts = load_manifest_lazy(
+                self.fbank_dir / "tal_csasr_cuts_train_set.jsonl.gz"
+            )
+
         # LibriSpeech
         train_clean_100_cuts = self.train_clean_100_cuts()
         train_clean_360_cuts = self.train_clean_360_cuts()
         train_other_500_cuts = self.train_other_500_cuts()

+        if self.use_tal_csasr:
+            return CutSet.mux(
+                aishell_2_cuts,
+                train_clean_100_cuts,
+                train_clean_360_cuts,
+                train_other_500_cuts,
+                tal_csasr_cuts,
+                weights=[
+                    len(aishell_2_cuts),
+                    len(train_clean_100_cuts),
+                    len(train_clean_360_cuts),
+                    len(train_other_500_cuts),
+                    len(tal_csasr_cuts),
+                ],
+            )
         return CutSet.mux(
             aishell_2_cuts,
             train_clean_100_cuts,
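The weighted mux above is lhotse's standard way of blending several corpora: CutSet.mux interleaves the input cut sets by sampling from each with probability proportional to its weight, and weighting by the number of cuts keeps the blend roughly proportional to corpus size. A minimal sketch of the same pattern follows; the data/fbank prefix is assumed, and only two sources are shown.

# Sketch of the weighted-mixing pattern used in train_cuts() above.
from lhotse import CutSet, load_manifest_lazy

aishell_2_cuts = load_manifest_lazy("data/fbank/aishell2_cuts_train.jsonl.gz")
tal_csasr_cuts = load_manifest_lazy("data/fbank/tal_csasr_cuts_train_set.jsonl.gz")

mixed = CutSet.mux(
    aishell_2_cuts,
    tal_csasr_cuts,
    # Each source is sampled with probability proportional to its weight,
    # here its cut count, so larger corpora contribute proportionally more.
    weights=[len(aishell_2_cuts), len(tal_csasr_cuts)],
)

for cut in mixed:
    pass  # cuts from both corpora arrive interleaved, lazily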
@@ -468,6 +468,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--use-tal-csasr",
+        type=str2bool,
+        default=False,
+        help="Whether to use TAL-CSASR training data.",
+    )
+
     add_model_arguments(parser)

     return parser
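With this flag the extra data is opt-in: the default (False) keeps the previous training mix, and passing --use-tal-csasr true adds the TAL-CSASR cuts to the mux. A short sketch of the assumed wiring from get_parser() into the dataset class, using only names that appear in the diff:

# Sketch only: the assumed call-site wiring between parser and dataset.
parser = get_parser()                                   # get_parser() as extended above
args = parser.parse_args(["--use-tal-csasr", "True"])
assert args.use_tal_csasr is True                       # str2bool converts the string to a bool

multi_dataset = MultiDataset(args)
train_cuts = multi_dataset.train_cuts()                 # now includes TAL-CSASR cuts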