Fix for multi-gpu

yifanyeung 2024-05-04 17:36:33 +08:00
parent c08fe48603
commit 322baa2593
6 changed files with 54 additions and 3 deletions


@@ -132,6 +132,8 @@ class LibriSpeechAsrDataModule:
         cuts_train: CutSet,
         do_normalize: bool,
         sampler_state_dict: Optional[Dict[str, Any]] = None,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
     ) -> DataLoader:
         """
         Args:
@@ -151,6 +153,8 @@ class LibriSpeechAsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=self.args.drop_last,
+                world_size=world_size,
+                rank=rank,
             )
         else:
             logging.info("Using SimpleCutSampler.")
@@ -158,6 +162,8 @@ class LibriSpeechAsrDataModule:
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
+                world_size=world_size,
+                rank=rank,
             )
         logging.info("About to create train dataloader")
@@ -181,13 +187,21 @@ class LibriSpeechAsrDataModule:
         return train_dl

-    def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
+    def valid_dataloaders(
+        self,
+        cuts_valid: CutSet,
+        do_normalize: bool,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
+    ) -> DataLoader:
         logging.info("About to create dev dataset")
         validate = HubertAsrDataset(do_normalize=do_normalize)
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
+            world_size=world_size,
+            rank=rank,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
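
Not part of the commit, but a minimal sketch of what the new arguments do: lhotse samplers shard a CutSet across DDP processes when given an explicit world_size and rank. The helper name make_train_sampler below is hypothetical; the parameters mirror the data-module changes above.

# Hypothetical helper illustrating the world_size/rank plumbing added above.
from typing import Optional

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler


def make_train_sampler(
    cuts_train: CutSet,
    max_duration: float = 600.0,
    world_size: Optional[int] = None,  # total number of DDP processes
    rank: Optional[int] = None,  # index of the current process
) -> DynamicBucketingSampler:
    # If world_size/rank are left as None, the sampler tries to infer them from
    # torch.distributed; passing them explicitly, as this commit does, keeps the
    # per-rank data shards correct even when the sampler is built outside an
    # initialized process group.
    return DynamicBucketingSampler(
        cuts_train,
        max_duration=max_duration,  # cap on total speech seconds per batch
        shuffle=True,
        drop_last=True,
        world_size=world_size,
        rank=rank,
    )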


@@ -144,6 +144,8 @@ class LibriSpeechDataModule:
         num_classes: list = [504],
         do_normalize: bool = True,
         sampler_state_dict: Optional[Dict[str, Any]] = None,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
     ) -> DataLoader:
         """
         Args:
@@ -171,6 +173,8 @@ class LibriSpeechDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=self.args.drop_last,
+                world_size=world_size,
+                rank=rank,
             )
         else:
             logging.info("Using SimpleCutSampler.")
@@ -178,6 +182,8 @@ class LibriSpeechDataModule:
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
+                world_size=world_size,
+                rank=rank,
             )
         logging.info("About to create train dataloader")
@@ -211,6 +217,8 @@ class LibriSpeechDataModule:
         pad_audio: bool = False,
         num_classes: list = [504],
         do_normalize: bool = True,
+        world_size: Optional[int] = None,
+        rank: Optional[int] = None,
     ) -> DataLoader:
         logging.info("About to create dev dataset")
         validate = HubertDataset(
@@ -226,6 +234,8 @@ class LibriSpeechDataModule:
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
+            world_size=world_size,
+            rank=rank,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(

egs/librispeech/SSL/pretrain.sh (new executable file, 20 lines)

@@ -0,0 +1,20 @@
+./zipformer/pretrain.py \
+  --world-size 8 \
+  --num-epochs 300 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer/exp_pretrain \
+  --full-libri 1 \
+  --max-duration 600 \
+  --accum-grad 1 \
+  --do-normalize 0 \
+  --mask-prob 0.8 \
+  --dropout-input 0.1 \
+  --dropout-features 0.1 \
+  --feature-grad-mult 0.1 \
+  --untie-final-proj 1 \
+  --num-encoder-layers 2,2,3,4,3,2 \
+  --feedforward-dim 512,768,1024,1536,1024,768 \
+  --encoder-dim 192,256,448,768,448,192 \
+  --encoder-unmasked-dim 192,192,256,256,256,192 \
+  --base-lr 0.045

egs/librispeech/SSL/zipformer/decode.py (Normal file → Executable file, mode change only)

egs/librispeech/SSL/zipformer/finetune.py (Normal file → Executable file, 4 additions)

@@ -1387,6 +1387,8 @@ def run(rank, world_size, args):
        train_cuts,
        do_normalize=params.do_normalize,
        sampler_state_dict=sampler_state_dict,
+        world_size=world_size,
+        rank=rank,
    )

    valid_cuts = librispeech.dev_clean_cuts()
@@ -1395,6 +1397,8 @@ def run(rank, world_size, args):
    valid_dl = librispeech.valid_dataloaders(
        valid_cuts,
        do_normalize=params.do_normalize,
+        world_size=world_size,
+        rank=rank,
    )

    if params.sanity_check and not params.print_diagnostics:

egs/librispeech/SSL/zipformer/pretrain.py (Normal file → Executable file, 7 changes)

@@ -41,7 +41,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
import argparse
import copy
import logging
-import sys
import warnings
from pathlib import Path
from shutil import copyfile
@@ -594,7 +593,7 @@ def get_parser():
    parser.add_argument(
        "--max-keep-size",
        type=int,
-        default=sys.maxsize,
+        default=320000,
        help="exclude sample longer than this.",
    )
@@ -1218,6 +1217,8 @@ def run(rank, world_size, args):
        num_classes=params.num_classes,
        do_normalize=params.do_normalize,
        sampler_state_dict=sampler_state_dict,
+        world_size=world_size,
+        rank=rank,
    )

    valid_cuts = librispeech.dev_clean_cuts()
@@ -1233,6 +1234,8 @@ def run(rank, world_size, args):
        pad_audio=False,
        num_classes=params.num_classes,
        do_normalize=params.do_normalize,
+        world_size=world_size,
+        rank=rank,
    )

    if params.sanity_check and not params.print_diagnostics:
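
For context, a rough sketch (assumed structure, details omitted) of how rank and world_size reach these calls: the recipe spawns one process per GPU and hands each process its rank, which is then forwarded to the data module so sampler shards line up with the model replicas.

# Sketch of the multi-process launch pattern; `args` is assumed to carry world_size.
import torch.multiprocessing as mp


def run(rank: int, world_size: int, args) -> None:
    # ... set up DDP, build the model, then create dataloaders with explicit
    # sharding, e.g.:
    #   train_dl = librispeech.train_dataloaders(
    #       train_cuts,
    #       sampler_state_dict=sampler_state_dict,
    #       world_size=world_size,
    #       rank=rank,
    #   )
    pass


def main(args) -> None:
    world_size = args.world_size
    if world_size > 1:
        # Each spawned process receives its rank as the first argument of run().
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)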