Fix for multi-gpu

yifanyeung 2024-05-04 17:36:33 +08:00
parent c08fe48603
commit 322baa2593
6 changed files with 54 additions and 3 deletions
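In distributed (DDP) training, every process must receive its own rank and the total world size so that the lhotse samplers can give each GPU a distinct shard of the cuts. This commit threads those two values from run(rank, world_size, args) down through the data modules into every sampler. A minimal, illustrative sketch of the sampler call this plumbing ends up making (not code from the repository; cuts_train stands for any lhotse CutSet):

from typing import Optional
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

def make_train_sampler(
    cuts_train: CutSet,
    world_size: Optional[int] = None,
    rank: Optional[int] = None,
) -> DynamicBucketingSampler:
    # Illustrative sketch only: explicit sharding arguments keep each DDP
    # process on its own slice of the training cuts.
    return DynamicBucketingSampler(
        cuts_train,
        max_duration=600.0,   # seconds of audio per batch (matches pretrain.sh below)
        shuffle=True,
        num_buckets=30,
        drop_last=True,
        world_size=world_size,  # total number of DDP processes
        rank=rank,              # this process's index in [0, world_size)
    )

When both arguments are left as None, lhotse infers them from torch.distributed if a process group has been initialized and otherwise assumes a single process; passing them explicitly makes the sharding independent of when that initialization happens.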


@@ -132,6 +132,8 @@ class LibriSpeechAsrDataModule:
cuts_train: CutSet,
do_normalize: bool,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@@ -151,6 +153,8 @@ class LibriSpeechAsrDataModule:
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@@ -158,6 +162,8 @@ class LibriSpeechAsrDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@@ -181,13 +187,21 @@ class LibriSpeechAsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
def valid_dataloaders(
self,
cuts_valid: CutSet,
do_normalize: bool,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertAsrDataset(do_normalize=do_normalize)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(

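The non-bucketing branch above ("Using SimpleCutSampler.") gets the same two arguments; a correspondingly minimal sketch, again illustrative rather than taken from the repo:

from typing import Optional
from lhotse import CutSet
from lhotse.dataset import SimpleCutSampler

def make_simple_train_sampler(
    cuts_train: CutSet,
    world_size: Optional[int] = None,
    rank: Optional[int] = None,
) -> SimpleCutSampler:
    # Same sharding contract as the bucketing sampler, without duration buckets.
    return SimpleCutSampler(
        cuts_train,
        max_duration=600.0,
        shuffle=True,
        world_size=world_size,
        rank=rank,
    )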

@@ -144,6 +144,8 @@ class LibriSpeechDataModule:
num_classes: list = [504],
do_normalize: bool = True,
sampler_state_dict: Optional[Dict[str, Any]] = None,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@@ -171,6 +173,8 @@ class LibriSpeechDataModule:
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
world_size=world_size,
rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
@@ -178,6 +182,8 @@ class LibriSpeechDataModule:
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
world_size=world_size,
rank=rank,
)
logging.info("About to create train dataloader")
@@ -211,6 +217,8 @@ class LibriSpeechDataModule:
pad_audio: bool = False,
num_classes: list = [504],
do_normalize: bool = True,
world_size: Optional[int] = None,
rank: Optional[int] = None,
) -> DataLoader:
logging.info("About to create dev dataset")
validate = HubertDataset(
@@ -226,6 +234,8 @@ class LibriSpeechDataModule:
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
world_size=world_size,
rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(

egs/librispeech/SSL/pretrain.sh (new executable file, 20 changes)

@@ -0,0 +1,20 @@
./zipformer/pretrain.py \
--world-size 8 \
--num-epochs 300 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp_pretrain \
--full-libri 1 \
--max-duration 600 \
--accum-grad 1 \
--do-normalize 0 \
--mask-prob 0.8 \
--dropout-input 0.1 \
--dropout-features 0.1 \
--feature-grad-mult 0.1 \
--untie-final-proj 1 \
--num-encoder-layers 2,2,3,4,3,2 \
--feedforward-dim 512,768,1024,1536,1024,768 \
--encoder-dim 192,256,448,768,448,192 \
--encoder-unmasked-dim 192,192,256,256,256,192 \
--base-lr 0.045

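pretrain.sh, new in this commit, launches the zipformer SSL pretraining on 8 GPUs. The --world-size 8 flag makes pretrain.py spawn one DDP process per GPU, and each process then forwards its (world_size, rank) pair to the dataloaders, as in the pretrain.py diff further down. A rough sketch of that launch pattern, with hypothetical scaffolding standing in for icefall's own helpers:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank: int, world_size: int, args=None) -> None:
    # Hypothetical stand-in for the recipe's run(): set up DDP for this process;
    # the real one also builds the model and the dataloaders, now passing
    # world_size=world_size, rank=rank as this commit adds.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12354")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... training loop would go here ...
    dist.destroy_process_group()

def main(world_size: int = 8, args=None) -> None:
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)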
egs/librispeech/SSL/zipformer/decode.py (Normal file → Executable file, 0 changes)

egs/librispeech/SSL/zipformer/finetune.py (Normal file → Executable file, 4 changes)

@@ -1387,6 +1387,8 @@ def run(rank, world_size, args):
train_cuts,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@@ -1395,6 +1397,8 @@ def run(rank, world_size, args):
valid_dl = librispeech.valid_dataloaders(
valid_cuts,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics:

egs/librispeech/SSL/zipformer/pretrain.py (Normal file → Executable file, 7 changes)

@@ -41,7 +41,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
import argparse
import copy
import logging
import sys
import warnings
from pathlib import Path
from shutil import copyfile
@@ -594,7 +593,7 @@ def get_parser():
parser.add_argument(
"--max-keep-size",
type=int,
default=sys.maxsize,
default=320000,
help="exclude sample longer than this.",
)
@@ -1218,6 +1217,8 @@ def run(rank, world_size, args):
num_classes=params.num_classes,
do_normalize=params.do_normalize,
sampler_state_dict=sampler_state_dict,
world_size=world_size,
rank=rank,
)
valid_cuts = librispeech.dev_clean_cuts()
@@ -1233,6 +1234,8 @@ def run(rank, world_size, args):
pad_audio=False,
num_classes=params.num_classes,
do_normalize=params.do_normalize,
world_size=world_size,
rank=rank,
)
if params.sanity_check and not params.print_diagnostics: