diff --git a/egs/librispeech/ASR/rnn_lm/dataset.py b/egs/librispeech/ASR/rnn_lm/dataset.py
index 0fc5b3d84..2da9539d1 100644
--- a/egs/librispeech/ASR/rnn_lm/dataset.py
+++ b/egs/librispeech/ASR/rnn_lm/dataset.py
@@ -312,6 +312,5 @@ def get_dataloader(
         collate_fn=collate_fn,
         sampler=sampler,
         shuffle=sampler is None,
-        num_workers=2,
     )
     return dataloader
diff --git a/egs/librispeech/ASR/rnn_lm/train.py b/egs/librispeech/ASR/rnn_lm/train.py
index a6640948a..27fc237b4 100755
--- a/egs/librispeech/ASR/rnn_lm/train.py
+++ b/egs/librispeech/ASR/rnn_lm/train.py
@@ -19,20 +19,14 @@
 Usage:
     ./rnn_lm/train.py \
         --start-epoch 0 \
-        --num-epochs 20 \
-        --batch-size 200 \
+        --world-size 2 \
+        --num-epochs 1 \
+        --use-fp16 0 \
+        --embedding-dim 800 \
+        --hidden-dim 200 \
+        --num-layers 2\
+        --batch-size 400

-If you want to use DDP training, e.g., a single node with 4 GPUs,
-use:
-
-    python -m torch.distributed.launch \
-        --use_env \
-        --nproc_per_node 4 \
-        ./rnn_lm/train.py \
-        --use-ddp-launch true \
-        --start-epoch 0 \
-        --num-epochs 10 \
-        --batch-size 200
 """

 import argparse
@@ -46,29 +40,18 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from dataset import get_dataloader
 from lhotse.utils import fix_random_seed
-from rnn_lm.dataset import get_dataloader
-from rnn_lm.model import RnnLmModel
+from model import RnnLmModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter

 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dist import (
-    cleanup_dist,
-    get_local_rank,
-    get_rank,
-    get_world_size,
-    setup_dist,
-)
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    get_env_info,
-    setup_logger,
-    str2bool,
-)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool


 def get_parser():
@@ -186,6 +169,22 @@ def get_parser():
         help="Number of RNN layers the model",
     )

+    parser.add_argument(
+        "--tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
     return parser


@@ -513,23 +512,13 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
+    is_distributed = world_size > 1

-    if params.use_ddp_launch:
-        local_rank = get_local_rank()
-    else:
-        local_rank = rank
+    fix_random_seed(params.seed)
+    if is_distributed:
+        setup_dist(rank, world_size, params.master_port)

-    logging.warning(
-        f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
-    )
-
-    fix_random_seed(42)
-    if world_size > 1:
-        setup_dist(rank, world_size, params.master_port, params.use_ddp_launch)
-
-    setup_logger(
-        f"{params.exp_dir}/log/log-train", rank=rank, world_size=world_size
-    )
+    setup_logger(f"{params.exp_dir}/log/log-train")

     logging.info("Training started")
     logging.info(params)
@@ -540,9 +529,9 @@ def run(rank, world_size, args):

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", local_rank)
+        device = torch.device("cuda", rank)

-    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
+    logging.info(f"Device: {device}")

     logging.info("About to create model")
     model = RnnLmModel(
@@ -555,8 +544,8 @@ def run(rank, world_size, args):
     checkpoints = load_checkpoint_if_available(params=params, model=model)

     model.to(device)
-    if world_size > 1:
-        model = DDP(model, device_ids=[local_rank])
+    if is_distributed:
+        model = DDP(model, device_ids=[rank])

     model.device = device

@@ -572,20 +561,20 @@ def run(rank, world_size, args):
     logging.info(f"Loading LM training data from {params.lm_data}")
     train_dl = get_dataloader(
         filename=params.lm_data,
-        is_distributed=world_size > 1,
+        is_distributed=is_distributed,
         params=params,
     )

     logging.info(f"Loading LM validation data from {params.lm_data_valid}")
     valid_dl = get_dataloader(
         filename=params.lm_data_valid,
-        is_distributed=world_size > 1,
+        is_distributed=is_distributed,
         params=params,
     )

     # Note: No learning rate scheduler is used here
     for epoch in range(params.start_epoch, params.num_epochs):
-        if world_size > 1:
+        if is_distributed:
             train_dl.sampler.set_epoch(epoch)

         params.cur_epoch = epoch
@@ -609,7 +598,7 @@ def run(rank, world_size, args):

     logging.info("Done!")

-    if world_size > 1:
+    if is_distributed:
         torch.distributed.barrier()
         cleanup_dist()

@@ -619,17 +608,6 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

-    if args.use_ddp_launch:
-        # for torch.distributed.lanunch
-        rank = get_rank()
-        world_size = get_world_size()
-        print(f"rank: {rank}, world_size: {world_size}")
-        # This following is a hack as the default log level
-        # is warning
-        logging.info = logging.warning
-        run(rank=rank, world_size=world_size, args=args)
-        return
-
     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1: