mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-12 03:22:19 +00:00)
commit 5aafaa35bd
zipformer/asr_datamodule.py:

@@ -1,5 +1,6 @@
 # Copyright 2021 Piotr Żelasko
 # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+# Copyright 2023 NVIDIA Corporation (Author: Wen Ding)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -43,7 +44,6 @@ from torch.utils.data import DataLoader
-
 from icefall.utils import str2bool


 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
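The `_SeedWorkers` helper kept unchanged above is a picklable callable that reseeds each DataLoader worker process. A minimal sketch of how it is typically wired in (the `worker_init_fn` hookup mirrors what icefall's data modules do; the DataLoader arguments in the comment are placeholders):

    from lhotse.utils import fix_random_seed


    class _SeedWorkers:
        """Picklable callable: gives every worker a distinct, reproducible seed."""

        def __init__(self, seed: int):
            self.seed = seed

        def __call__(self, worker_id: int):
            # Each worker derives its seed from the base seed plus its id.
            fix_random_seed(self.seed + worker_id)


    # Illustrative hookup; dataset, sampler, and num_workers are placeholders.
    # train_dl = DataLoader(
    #     dataset,
    #     batch_sampler=sampler,
    #     num_workers=2,
    #     worker_init_fn=_SeedWorkers(seed),
    # )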
@@ -52,7 +52,7 @@ class _SeedWorkers:
         fix_random_seed(self.seed + worker_id)


-class LibriSpeechAsrDataModule:
+class ICMCAsrDataModule:
     """
     DataModule for k2 ASR experiments.
     It assumes there is always one train and valid dataloader,
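For orientation, the renamed module is consumed exactly like the LibriSpeech one it replaces. A hypothetical usage sketch (all of these calls appear in the train.py hunks further below; `args` comes from the parser that `ICMCAsrDataModule.add_arguments()` populates):

    # Hypothetical wiring, assuming a parsed `args` namespace.
    icmc = ICMCAsrDataModule(args)

    train_cuts = icmc.train_ihm_cuts()
    train_dl = icmc.train_dataloaders(train_cuts)

    valid_cuts = icmc.dev_ihm_cuts()
    valid_dl = icmc.valid_dataloaders(valid_cuts)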
@@ -82,20 +82,19 @@ class LibriSpeechAsrDataModule:
             "effective batch sizes, sampling strategies, applied data "
             "augmentations, etc.",
         )

         group.add_argument(
-            "--full-libri",
+            "--ihm-only",
             type=str2bool,
             default=True,
-            help="""Used only when --mini-libri is False.When enabled,
-            use 960h LibriSpeech. Otherwise, use 100h subset.""",
+            help="True to use only IHM data for training",
         )
         group.add_argument(
-            "--mini-libri",
+            "--full-data",
             type=str2bool,
             default=False,
-            help="True for mini librispeech",
+            help="True to use all data",
         )

         group.add_argument(
             "--manifest-dir",
             type=Path,
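A self-contained sketch of how the renamed flags behave, assuming icefall's `str2bool` helper (imported at the top of the file) maps strings like "1"/"true" to booleans; the `--manifest-dir` default here is illustrative only:

    import argparse
    from pathlib import Path

    from icefall.utils import str2bool

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title="ASR data related options")
    group.add_argument("--ihm-only", type=str2bool, default=True,
                       help="True to use only IHM data for training")
    group.add_argument("--full-data", type=str2bool, default=False,
                       help="True to use all data")
    # Default path is a placeholder, not the recipe's actual default.
    group.add_argument("--manifest-dir", type=Path, default=Path("data/manifests"))

    # `--full-data 1` enables the extra IHM-reverb and SDM cuts in train.py.
    args = parser.parse_args(["--full-data", "1"])
    assert args.full_data is True and args.ihm_only is True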
@@ -402,74 +401,50 @@ class LibriSpeechAsrDataModule:
         return test_dl

     @lru_cache()
-    def train_clean_5_cuts(self) -> CutSet:
-        logging.info("mini_librispeech: About to get train-clean-5 cuts")
+    def train_ihm_cuts(self) -> CutSet:
+        logging.info("About to get train-ihm cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+            self.args.manifest_dir / "cuts_train_ihm.jsonl.gz"
         )

     @lru_cache()
-    def train_clean_100_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-100 cuts")
+    def train_ihm_rvb_cuts(self) -> CutSet:
+        logging.info("About to get train-ihm-rvb cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+            self.args.manifest_dir / "cuts_train_ihm_rvb.jsonl.gz"
         )

     @lru_cache()
-    def train_clean_360_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-360 cuts")
+    def train_shm_cuts(self) -> CutSet:
+        logging.info("About to get train-shm cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+            self.args.manifest_dir / "cuts_train_sdm.jsonl.gz"
         )

     @lru_cache()
-    def train_other_500_cuts(self) -> CutSet:
-        logging.info("About to get train-other-500 cuts")
+    def dev_ihm_cuts(self) -> CutSet:
+        logging.info("About to get dev-ihm cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+            self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz"
         )

     @lru_cache()
-    def train_all_shuf_cuts(self) -> CutSet:
-        logging.info(
-            "About to get the shuffled train-clean-100, \
-            train-clean-360 and train-other-500 cuts"
-        )
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
-        )
-
-    @lru_cache()
-    def dev_clean_2_cuts(self) -> CutSet:
-        logging.info("mini_librispeech: About to get dev-clean-2 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
-        )
-
-    @lru_cache()
-    def dev_clean_cuts(self) -> CutSet:
-        logging.info("About to get dev-clean cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
-        )
-
-    @lru_cache()
-    def dev_other_cuts(self) -> CutSet:
+    def dev_shm_cuts(self) -> CutSet:
         logging.info("About to get dev-other cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+            self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz"
         )

-    @lru_cache()
-    def test_clean_cuts(self) -> CutSet:
-        logging.info("About to get test-clean cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
-        )
+    # @lru_cache()
+    # def test_clean_cuts(self) -> CutSet:
+    #     logging.info("About to get test-clean cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+    #     )

-    @lru_cache()
-    def test_other_cuts(self) -> CutSet:
-        logging.info("About to get test-other cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
-        )
+    # @lru_cache()
+    # def test_other_cuts(self) -> CutSet:
+    #     logging.info("About to get test-other cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+    #     )
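Each getter above follows one pattern: log a message, then lazily open a Lhotse manifest, with `@lru_cache()` making repeat calls return the same `CutSet`. A condensed standalone sketch of that pattern (the class name and manifest path are examples, not the recipe's code):

    import logging
    from functools import lru_cache
    from pathlib import Path

    from lhotse import CutSet, load_manifest_lazy


    class _CutsDemo:
        """Stand-in for the data module's getter pattern (illustrative)."""

        def __init__(self, manifest_dir: Path):
            self.manifest_dir = manifest_dir

        @lru_cache()
        def train_ihm_cuts(self) -> CutSet:
            logging.info("About to get train-ihm cuts")
            # load_manifest_lazy streams the .jsonl.gz instead of
            # materializing every cut in memory at once.
            return load_manifest_lazy(self.manifest_dir / "cuts_train_ihm.jsonl.gz")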
zipformer/train.py:

@@ -25,13 +25,14 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 # For non-streaming model training:
 ./zipformer/train.py \
-  --world-size 4 \
+  --world-size 1 \
   --num-epochs 30 \
   --start-epoch 1 \
-  --use-fp16 1 \
   --exp-dir zipformer/exp \
-  --full-libri 1 \
-  --max-duration 1000
+  --manifest-dir '/mnt/samsung-t7/yuekai/asr/icefall-icmcasr/egs/icmcasr/ASR/data/manifests' \
+  --max-duration 1000 \
+  --bpe-model /raid/wend/asr/icmc/multi_zh/icefall-asr-multi-zh-hans-zipformer-2023-9-2/data/lang_bpe_2000/bpe.model


 # For streaming model training:
 ./zipformer/train.py \
@@ -65,7 +66,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import ICMCAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -1172,12 +1173,12 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    librispeech = LibriSpeechAsrDataModule(args)
+    icmc = ICMCAsrDataModule(args)

-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    train_cuts = icmc.train_ihm_cuts()
+    if params.full_data:
+        train_cuts += icmc.train_ihm_rvb_cuts()
+        train_cuts += icmc.train_shm_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
|
|||||||
else:
|
else:
|
||||||
sampler_state_dict = None
|
sampler_state_dict = None
|
||||||
|
|
||||||
train_dl = librispeech.train_dataloaders(
|
train_dl = icmc.train_dataloaders(
|
||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = librispeech.dev_clean_cuts()
|
valid_cuts = icmc.dev_ihm_cuts()
|
||||||
valid_cuts += librispeech.dev_other_cuts()
|
# valid_cuts += librispeech.dev_other_cuts()
|
||||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
valid_dl = icmc.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
@ -1370,7 +1371,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
ICMCAsrDataModule.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||