mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Fix for multi-gpu
This commit is contained in:
parent
c08fe48603
commit
322baa2593
@ -132,6 +132,8 @@ class LibriSpeechAsrDataModule:
|
||||
cuts_train: CutSet,
|
||||
do_normalize: bool,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
world_size: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
@ -151,6 +153,8 @@ class LibriSpeechAsrDataModule:
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
@ -158,6 +162,8 @@ class LibriSpeechAsrDataModule:
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
@ -181,13 +187,21 @@ class LibriSpeechAsrDataModule:
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
|
||||
def valid_dataloaders(
|
||||
self,
|
||||
cuts_valid: CutSet,
|
||||
do_normalize: bool,
|
||||
world_size: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
validate = HubertAsrDataset(do_normalize=do_normalize)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
|
@ -144,6 +144,8 @@ class LibriSpeechDataModule:
|
||||
num_classes: list = [504],
|
||||
do_normalize: bool = True,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
world_size: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
@ -171,6 +173,8 @@ class LibriSpeechDataModule:
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
@ -178,6 +182,8 @@ class LibriSpeechDataModule:
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
@ -211,6 +217,8 @@ class LibriSpeechDataModule:
|
||||
pad_audio: bool = False,
|
||||
num_classes: list = [504],
|
||||
do_normalize: bool = True,
|
||||
world_size: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
validate = HubertDataset(
|
||||
@ -226,6 +234,8 @@ class LibriSpeechDataModule:
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
|
20
egs/librispeech/SSL/pretrain.sh
Executable file
20
egs/librispeech/SSL/pretrain.sh
Executable file
@ -0,0 +1,20 @@
|
||||
./zipformer/pretrain.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 300 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer/exp_pretrain \
|
||||
--full-libri 1 \
|
||||
--max-duration 600 \
|
||||
--accum-grad 1 \
|
||||
--do-normalize 0 \
|
||||
--mask-prob 0.8 \
|
||||
--dropout-input 0.1 \
|
||||
--dropout-features 0.1 \
|
||||
--feature-grad-mult 0.1 \
|
||||
--untie-final-proj 1 \
|
||||
--num-encoder-layers 2,2,3,4,3,2 \
|
||||
--feedforward-dim 512,768,1024,1536,1024,768 \
|
||||
--encoder-dim 192,256,448,768,448,192 \
|
||||
--encoder-unmasked-dim 192,192,256,256,256,192 \
|
||||
--base-lr 0.045
|
0
egs/librispeech/SSL/zipformer/decode.py
Normal file → Executable file
0
egs/librispeech/SSL/zipformer/decode.py
Normal file → Executable file
4
egs/librispeech/SSL/zipformer/finetune.py
Normal file → Executable file
4
egs/librispeech/SSL/zipformer/finetune.py
Normal file → Executable file
@ -1387,6 +1387,8 @@ def run(rank, world_size, args):
|
||||
train_cuts,
|
||||
do_normalize=params.do_normalize,
|
||||
sampler_state_dict=sampler_state_dict,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
@ -1395,6 +1397,8 @@ def run(rank, world_size, args):
|
||||
valid_dl = librispeech.valid_dataloaders(
|
||||
valid_cuts,
|
||||
do_normalize=params.do_normalize,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.sanity_check and not params.print_diagnostics:
|
||||
|
7
egs/librispeech/SSL/zipformer/pretrain.py
Normal file → Executable file
7
egs/librispeech/SSL/zipformer/pretrain.py
Normal file → Executable file
@ -41,7 +41,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
@ -594,7 +593,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--max-keep-size",
|
||||
type=int,
|
||||
default=sys.maxsize,
|
||||
default=320000,
|
||||
help="exclude sample longer than this.",
|
||||
)
|
||||
|
||||
@ -1218,6 +1217,8 @@ def run(rank, world_size, args):
|
||||
num_classes=params.num_classes,
|
||||
do_normalize=params.do_normalize,
|
||||
sampler_state_dict=sampler_state_dict,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
valid_cuts = librispeech.dev_clean_cuts()
|
||||
@ -1233,6 +1234,8 @@ def run(rank, world_size, args):
|
||||
pad_audio=False,
|
||||
num_classes=params.num_classes,
|
||||
do_normalize=params.do_normalize,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.sanity_check and not params.print_diagnostics:
|
||||
|
Loading…
x
Reference in New Issue
Block a user