diff --git a/egs/librispeech/SSL/hubert/asr_datamodule.py b/egs/librispeech/SSL/hubert/asr_datamodule.py index 3746d8a3a..5e7b0cba9 100644 --- a/egs/librispeech/SSL/hubert/asr_datamodule.py +++ b/egs/librispeech/SSL/hubert/asr_datamodule.py @@ -132,6 +132,8 @@ class LibriSpeechAsrDataModule: cuts_train: CutSet, do_normalize: bool, sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: Optional[int] = None, + rank: Optional[int] = None, ) -> DataLoader: """ Args: @@ -151,6 +153,8 @@ class LibriSpeechAsrDataModule: shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, ) else: logging.info("Using SimpleCutSampler.") @@ -158,6 +162,8 @@ class LibriSpeechAsrDataModule: cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, + world_size=world_size, + rank=rank, ) logging.info("About to create train dataloader") @@ -181,13 +187,21 @@ class LibriSpeechAsrDataModule: return train_dl - def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader: + def valid_dataloaders( + self, + cuts_valid: CutSet, + do_normalize: bool, + world_size: Optional[int] = None, + rank: Optional[int] = None, + ) -> DataLoader: logging.info("About to create dev dataset") validate = HubertAsrDataset(do_normalize=do_normalize) valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, shuffle=False, + world_size=world_size, + rank=rank, ) logging.info("About to create dev dataloader") valid_dl = DataLoader( diff --git a/egs/librispeech/SSL/hubert/ssl_datamodule.py b/egs/librispeech/SSL/hubert/ssl_datamodule.py index ac1a0997d..8c21fa352 100644 --- a/egs/librispeech/SSL/hubert/ssl_datamodule.py +++ b/egs/librispeech/SSL/hubert/ssl_datamodule.py @@ -144,6 +144,8 @@ class LibriSpeechDataModule: num_classes: list = [504], do_normalize: bool = True, sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: Optional[int] = None, + rank: Optional[int] = None, ) -> DataLoader: """ Args: @@ -171,6 +173,8 @@ class LibriSpeechDataModule: shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, ) else: logging.info("Using SimpleCutSampler.") @@ -178,6 +182,8 @@ class LibriSpeechDataModule: cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, + world_size=world_size, + rank=rank, ) logging.info("About to create train dataloader") @@ -211,6 +217,8 @@ class LibriSpeechDataModule: pad_audio: bool = False, num_classes: list = [504], do_normalize: bool = True, + world_size: Optional[int] = None, + rank: Optional[int] = None, ) -> DataLoader: logging.info("About to create dev dataset") validate = HubertDataset( @@ -226,6 +234,8 @@ class LibriSpeechDataModule: cuts_valid, max_duration=self.args.max_duration, shuffle=False, + world_size=world_size, + rank=rank, ) logging.info("About to create dev dataloader") valid_dl = DataLoader( diff --git a/egs/librispeech/SSL/pretrain.sh b/egs/librispeech/SSL/pretrain.sh new file mode 100755 index 000000000..98078ee6b --- /dev/null +++ b/egs/librispeech/SSL/pretrain.sh @@ -0,0 +1,20 @@ +./zipformer/pretrain.py \ + --world-size 8 \ + --num-epochs 300 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp_pretrain \ + --full-libri 1 \ + --max-duration 600 \ + --accum-grad 1 \ + --do-normalize 0 \ + --mask-prob 0.8 \ + --dropout-input 0.1 \ + --dropout-features 0.1 \ + --feature-grad-mult 0.1 \ + --untie-final-proj 1 \ + --num-encoder-layers 2,2,3,4,3,2 \ + --feedforward-dim 512,768,1024,1536,1024,768 \ + --encoder-dim 192,256,448,768,448,192 \ + --encoder-unmasked-dim 192,192,256,256,256,192 \ + --base-lr 0.045 diff --git a/egs/librispeech/SSL/zipformer/decode.py b/egs/librispeech/SSL/zipformer/decode.py old mode 100644 new mode 100755 diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py old mode 100644 new mode 100755 index bbb445320..55b1febaf --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -1387,6 +1387,8 @@ def run(rank, world_size, args): train_cuts, do_normalize=params.do_normalize, sampler_state_dict=sampler_state_dict, + world_size=world_size, + rank=rank, ) valid_cuts = librispeech.dev_clean_cuts() @@ -1395,6 +1397,8 @@ def run(rank, world_size, args): valid_dl = librispeech.valid_dataloaders( valid_cuts, do_normalize=params.do_normalize, + world_size=world_size, + rank=rank, ) if params.sanity_check and not params.print_diagnostics: diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py old mode 100644 new mode 100755 index 5f547e0b8..acb63ed97 --- a/egs/librispeech/SSL/zipformer/pretrain.py +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -41,7 +41,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" import argparse import copy import logging -import sys import warnings from pathlib import Path from shutil import copyfile @@ -594,7 +593,7 @@ def get_parser(): parser.add_argument( "--max-keep-size", type=int, - default=sys.maxsize, + default=320000, help="exclude sample longer than this.", ) @@ -1218,6 +1217,8 @@ def run(rank, world_size, args): num_classes=params.num_classes, do_normalize=params.do_normalize, sampler_state_dict=sampler_state_dict, + world_size=world_size, + rank=rank, ) valid_cuts = librispeech.dev_clean_cuts() @@ -1233,6 +1234,8 @@ def run(rank, world_size, args): pad_audio=False, num_classes=params.num_classes, do_normalize=params.do_normalize, + world_size=world_size, + rank=rank, ) if params.sanity_check and not params.print_diagnostics: