From 6a9cce407cf63ca1e3c60085363df77784aaaef3 Mon Sep 17 00:00:00 2001
From: wd929
Date: Fri, 20 Oct 2023 17:01:49 +0800
Subject: [PATCH] Training scripts for ICMC

Signed-off-by: wd929
---
 egs/icmcasr/ASR/zipformer/asr_datamodule.py | 91 ++++++++-------------
 egs/icmcasr/ASR/zipformer/train.py          | 31 +++----
 2 files changed, 49 insertions(+), 73 deletions(-)

diff --git a/egs/icmcasr/ASR/zipformer/asr_datamodule.py b/egs/icmcasr/ASR/zipformer/asr_datamodule.py
index 20df469da..cd6d8ea7f 100644
--- a/egs/icmcasr/ASR/zipformer/asr_datamodule.py
+++ b/egs/icmcasr/ASR/zipformer/asr_datamodule.py
@@ -1,5 +1,6 @@
 # Copyright      2021  Piotr Żelasko
 # Copyright      2022  Xiaomi Corporation     (Author: Mingshuang Luo)
+# Copyright      2023  NVIDIA Corporation     (Author: Wen Ding)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -43,7 +44,6 @@ from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
-
 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -52,7 +52,7 @@ class _SeedWorkers:
         fix_random_seed(self.seed + worker_id)
 
 
-class LibriSpeechAsrDataModule:
+class ICMCAsrDataModule:
     """
     DataModule for k2 ASR experiments.
     It assumes there is always one train and valid dataloader,
@@ -82,20 +82,19 @@ class LibriSpeechAsrDataModule:
             "effective batch sizes, sampling strategies, applied data "
             "augmentations, etc.",
         )
+
         group.add_argument(
-            "--full-libri",
+            "--ihm-only",
            type=str2bool,
             default=True,
-            help="""Used only when --mini-libri is False.When enabled,
-            use 960h LibriSpeech. Otherwise, use 100h subset.""",
+            help="True to use only the IHM data for training.",
         )
         group.add_argument(
-            "--mini-libri",
+            "--full-data",
             type=str2bool,
             default=False,
-            help="True for mini librispeech",
+            help="True to use all training data (ihm, ihm_rvb and sdm).",
         )
-
         group.add_argument(
             "--manifest-dir",
             type=Path,
@@ -402,74 +401,50 @@ class LibriSpeechAsrDataModule:
         return test_dl
 
     @lru_cache()
-    def train_clean_5_cuts(self) -> CutSet:
-        logging.info("mini_librispeech: About to get train-clean-5 cuts")
+    def train_ihm_cuts(self) -> CutSet:
+        logging.info("About to get train-ihm cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+            self.args.manifest_dir / "cuts_train_ihm.jsonl.gz"
         )
 
     @lru_cache()
-    def train_clean_100_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-100 cuts")
+    def train_ihm_rvb_cuts(self) -> CutSet:
+        logging.info("About to get train-ihm-rvb cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+            self.args.manifest_dir / "cuts_train_ihm_rvb.jsonl.gz"
         )
 
     @lru_cache()
-    def train_clean_360_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-360 cuts")
+    def train_sdm_cuts(self) -> CutSet:
+        logging.info("About to get train-sdm cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+            self.args.manifest_dir / "cuts_train_sdm.jsonl.gz"
         )
 
     @lru_cache()
-    def train_other_500_cuts(self) -> CutSet:
-        logging.info("About to get train-other-500 cuts")
+    def dev_ihm_cuts(self) -> CutSet:
+        logging.info("About to get dev-ihm cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+            self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz"
         )
 
     @lru_cache()
-    def train_all_shuf_cuts(self) -> CutSet:
-        logging.info(
-            "About to get the shuffled train-clean-100, \
-            train-clean-360 and train-other-500 cuts"
-        )
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
-        )
-
-    @lru_cache()
-    def dev_clean_2_cuts(self) -> CutSet:
-        logging.info("mini_librispeech: About to get dev-clean-2 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
-        )
-
-    @lru_cache()
-    def dev_clean_cuts(self) -> CutSet:
-        logging.info("About to get dev-clean cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
-        )
-
     @lru_cache()
-    def dev_other_cuts(self) -> CutSet:
-        logging.info("About to get dev-other cuts")
+    def dev_sdm_cuts(self) -> CutSet:
+        logging.info("About to get dev-sdm cuts")
         return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
+            self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz"
         )
 
-    @lru_cache()
-    def test_clean_cuts(self) -> CutSet:
-        logging.info("About to get test-clean cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
-        )
+    # @lru_cache()
+    # def test_clean_cuts(self) -> CutSet:
+    #     logging.info("About to get test-clean cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
+    #     )
 
-    @lru_cache()
-    def test_other_cuts(self) -> CutSet:
-        logging.info("About to get test-other cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
-        )
+    # @lru_cache()
+    # def test_other_cuts(self) -> CutSet:
+    #     logging.info("About to get test-other cuts")
+    #     return load_manifest_lazy(
+    #         self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
+    #     )
diff --git a/egs/icmcasr/ASR/zipformer/train.py b/egs/icmcasr/ASR/zipformer/train.py
index 7009f3346..413b3c661 100755
--- a/egs/icmcasr/ASR/zipformer/train.py
+++ b/egs/icmcasr/ASR/zipformer/train.py
@@ -25,13 +25,14 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 # For non-streaming model training:
 ./zipformer/train.py \
-  --world-size 4 \
+  --world-size 1 \
   --num-epochs 30 \
   --start-epoch 1 \
-  --use-fp16 1 \
   --exp-dir zipformer/exp \
-  --full-libri 1 \
-  --max-duration 1000
+  --manifest-dir '/mnt/samsung-t7/yuekai/asr/icefall-icmcasr/egs/icmcasr/ASR/data/manifests' \
+  --max-duration 1000 \
+  --bpe-model /raid/wend/asr/icmc/multi_zh/icefall-asr-multi-zh-hans-zipformer-2023-9-2/data/lang_bpe_2000/bpe.model
+
 
 # For streaming model training:
 ./zipformer/train.py \
@@ -65,7 +66,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import ICMCAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -1172,12 +1173,12 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    icmc = ICMCAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    train_cuts = icmc.train_ihm_cuts()
+    if params.full_data:
+        train_cuts += icmc.train_ihm_rvb_cuts()
+        train_cuts += icmc.train_sdm_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1225,13 +1226,13 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    train_dl = librispeech.train_dataloaders(
+    train_dl = icmc.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
+    valid_cuts = icmc.dev_ihm_cuts()
+    # valid_cuts += icmc.dev_sdm_cuts()
+    valid_dl = icmc.valid_dataloaders(valid_cuts)
 
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
@@ -1370,7 +1371,7 @@ def scan_pessimistic_batches_for_oom(
 
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    ICMCAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
 
     args.exp_dir = Path(args.exp_dir)
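
A quick way to confirm that the ICMC cut manifests referenced above resolve before launching train.py is a small check along the lines of the sketch below. This is only a sketch: it assumes lhotse is installed and that data preparation has produced the cuts_*.jsonl.gz files; MANIFEST_DIR is a placeholder that should match the value passed to --manifest-dir.

#!/usr/bin/env python3
# Sketch only: sanity-check the ICMC cut manifests used by ICMCAsrDataModule.
# Assumes lhotse is installed and the manifests below exist; MANIFEST_DIR is
# a placeholder and should match the --manifest-dir passed to train.py.
from pathlib import Path

from lhotse import load_manifest_lazy

MANIFEST_DIR = Path("data/manifests")  # placeholder

for name in (
    "cuts_train_ihm",
    "cuts_train_ihm_rvb",
    "cuts_train_sdm",
    "cuts_dev_ihm",
    "cuts_dev_sdm",
):
    cuts = load_manifest_lazy(MANIFEST_DIR / f"{name}.jsonl.gz")
    # Peek at the first few cuts to confirm durations look reasonable
    # (train.py keeps utterances between 1 and 20 seconds).
    durations = [cut.duration for _, cut in zip(range(5), cuts)]
    print(name, [round(d, 2) for d in durations])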