mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Fix for multi-gpu
This commit is contained in:
parent
c08fe48603
commit
322baa2593
@ -132,6 +132,8 @@ class LibriSpeechAsrDataModule:
|
|||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
do_normalize: bool,
|
do_normalize: bool,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -151,6 +153,8 @@ class LibriSpeechAsrDataModule:
|
|||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SimpleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
@ -158,6 +162,8 @@ class LibriSpeechAsrDataModule:
|
|||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
@ -181,13 +187,21 @@ class LibriSpeechAsrDataModule:
|
|||||||
|
|
||||||
return train_dl
|
return train_dl
|
||||||
|
|
||||||
def valid_dataloaders(self, cuts_valid: CutSet, do_normalize: bool) -> DataLoader:
|
def valid_dataloaders(
|
||||||
|
self,
|
||||||
|
cuts_valid: CutSet,
|
||||||
|
do_normalize: bool,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
|
) -> DataLoader:
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
validate = HubertAsrDataset(do_normalize=do_normalize)
|
validate = HubertAsrDataset(do_normalize=do_normalize)
|
||||||
valid_sampler = DynamicBucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
valid_dl = DataLoader(
|
valid_dl = DataLoader(
|
||||||
|
@ -144,6 +144,8 @@ class LibriSpeechDataModule:
|
|||||||
num_classes: list = [504],
|
num_classes: list = [504],
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -171,6 +173,8 @@ class LibriSpeechDataModule:
|
|||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using SimpleCutSampler.")
|
logging.info("Using SimpleCutSampler.")
|
||||||
@ -178,6 +182,8 @@ class LibriSpeechDataModule:
|
|||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
@ -211,6 +217,8 @@ class LibriSpeechDataModule:
|
|||||||
pad_audio: bool = False,
|
pad_audio: bool = False,
|
||||||
num_classes: list = [504],
|
num_classes: list = [504],
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
|
world_size: Optional[int] = None,
|
||||||
|
rank: Optional[int] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
validate = HubertDataset(
|
validate = HubertDataset(
|
||||||
@ -226,6 +234,8 @@ class LibriSpeechDataModule:
|
|||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
logging.info("About to create dev dataloader")
|
logging.info("About to create dev dataloader")
|
||||||
valid_dl = DataLoader(
|
valid_dl = DataLoader(
|
||||||
|
20
egs/librispeech/SSL/pretrain.sh
Executable file
20
egs/librispeech/SSL/pretrain.sh
Executable file
@ -0,0 +1,20 @@
|
|||||||
|
./zipformer/pretrain.py \
|
||||||
|
--world-size 8 \
|
||||||
|
--num-epochs 300 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--exp-dir zipformer/exp_pretrain \
|
||||||
|
--full-libri 1 \
|
||||||
|
--max-duration 600 \
|
||||||
|
--accum-grad 1 \
|
||||||
|
--do-normalize 0 \
|
||||||
|
--mask-prob 0.8 \
|
||||||
|
--dropout-input 0.1 \
|
||||||
|
--dropout-features 0.1 \
|
||||||
|
--feature-grad-mult 0.1 \
|
||||||
|
--untie-final-proj 1 \
|
||||||
|
--num-encoder-layers 2,2,3,4,3,2 \
|
||||||
|
--feedforward-dim 512,768,1024,1536,1024,768 \
|
||||||
|
--encoder-dim 192,256,448,768,448,192 \
|
||||||
|
--encoder-unmasked-dim 192,192,256,256,256,192 \
|
||||||
|
--base-lr 0.045
|
0
egs/librispeech/SSL/zipformer/decode.py
Normal file → Executable file
0
egs/librispeech/SSL/zipformer/decode.py
Normal file → Executable file
4
egs/librispeech/SSL/zipformer/finetune.py
Normal file → Executable file
4
egs/librispeech/SSL/zipformer/finetune.py
Normal file → Executable file
@ -1387,6 +1387,8 @@ def run(rank, world_size, args):
|
|||||||
train_cuts,
|
train_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
sampler_state_dict=sampler_state_dict,
|
sampler_state_dict=sampler_state_dict,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = librispeech.dev_clean_cuts()
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
@ -1395,6 +1397,8 @@ def run(rank, world_size, args):
|
|||||||
valid_dl = librispeech.valid_dataloaders(
|
valid_dl = librispeech.valid_dataloaders(
|
||||||
valid_cuts,
|
valid_cuts,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.sanity_check and not params.print_diagnostics:
|
if params.sanity_check and not params.print_diagnostics:
|
||||||
|
7
egs/librispeech/SSL/zipformer/pretrain.py
Normal file → Executable file
7
egs/librispeech/SSL/zipformer/pretrain.py
Normal file → Executable file
@ -41,7 +41,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -594,7 +593,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-keep-size",
|
"--max-keep-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=sys.maxsize,
|
default=320000,
|
||||||
help="exclude sample longer than this.",
|
help="exclude sample longer than this.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1218,6 +1217,8 @@ def run(rank, world_size, args):
|
|||||||
num_classes=params.num_classes,
|
num_classes=params.num_classes,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
sampler_state_dict=sampler_state_dict,
|
sampler_state_dict=sampler_state_dict,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_cuts = librispeech.dev_clean_cuts()
|
valid_cuts = librispeech.dev_clean_cuts()
|
||||||
@ -1233,6 +1234,8 @@ def run(rank, world_size, args):
|
|||||||
pad_audio=False,
|
pad_audio=False,
|
||||||
num_classes=params.num_classes,
|
num_classes=params.num_classes,
|
||||||
do_normalize=params.do_normalize,
|
do_normalize=params.do_normalize,
|
||||||
|
world_size=world_size,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.sanity_check and not params.print_diagnostics:
|
if params.sanity_check and not params.print_diagnostics:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user