Remove DDP

Erwan 2022-06-15 08:57:08 +02:00
parent 779589a2de
commit 25d540a758
2 changed files with 41 additions and 64 deletions

rnn_lm/dataset.py

@@ -312,6 +312,5 @@ def get_dataloader(
         collate_fn=collate_fn,
         sampler=sampler,
         shuffle=sampler is None,
-        num_workers=2,
     )
     return dataloader
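This hunk drops the explicit num_workers=2, so the DataLoader falls back to PyTorch's default of num_workers=0 (loading and collating in the main process). A minimal sketch of a dataloader built the way the surrounding context lines imply, with a DistributedSampler only in distributed training; the function name and the dataset/batch-size arguments are illustrative, not taken from the file:

    from torch.utils.data import DataLoader, DistributedSampler

    def make_dataloader(dataset, collate_fn, batch_size, is_distributed):
        # Shard the data across ranks only when training is distributed.
        sampler = DistributedSampler(dataset, shuffle=True) if is_distributed else None
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            sampler=sampler,
            # DataLoader rejects shuffle=True together with an explicit sampler,
            # hence the "shuffle=sampler is None" pattern seen in the hunk above.
            shuffle=sampler is None,
            # num_workers is left at its default of 0 after this commit.
        )
        return dataloader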

rnn_lm/train.py

@@ -19,20 +19,14 @@
 Usage:
     ./rnn_lm/train.py \
         --start-epoch 0 \
-        --num-epochs 20 \
-        --batch-size 200 \
-
-If you want to use DDP training, e.g., a single node with 4 GPUs,
-use:
-
-    python -m torch.distributed.launch \
-        --use_env \
-        --nproc_per_node 4 \
-    ./rnn_lm/train.py \
-        --use-ddp-launch true \
-        --start-epoch 0 \
-        --num-epochs 10 \
-        --batch-size 200
+        --world-size 2 \
+        --num-epochs 1 \
+        --use-fp16 0 \
+        --embedding-dim 800 \
+        --hidden-dim 200 \
+        --num-layers 2 \
+        --batch-size 400
 """

 import argparse
@@ -46,29 +40,18 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from dataset import get_dataloader
 from lhotse.utils import fix_random_seed
-from rnn_lm.dataset import get_dataloader
-from rnn_lm.model import RnnLmModel
+from model import RnnLmModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter

 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dist import (
-    cleanup_dist,
-    get_local_rank,
-    get_rank,
-    get_world_size,
-    setup_dist,
-)
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    get_env_info,
-    setup_logger,
-    str2bool,
-)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool


 def get_parser():
@@ -186,6 +169,22 @@ def get_parser():
         help="Number of RNN layers the model",
     )

+    parser.add_argument(
+        "--tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True to share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
     return parser
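The new --seed flag feeds fix_random_seed(params.seed) in the run() hunk below, replacing the hard-coded fix_random_seed(42). The new --tie-weights flag suggests weight tying between the input embedding and the output projection, a standard trick for RNN language models that saves parameters. The model code is not part of this diff, so the sketch below is only an illustration of how such a flag is commonly honored; the class and layer names are made up, not taken from RnnLmModel:

    import torch.nn as nn

    class TiedRnnLm(nn.Module):
        """Toy RNN LM illustrating a --tie-weights style option (names are illustrative)."""

        def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, tie_weights=False):
            super().__init__()
            self.input_embedding = nn.Embedding(vocab_size, embedding_dim)
            self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
            self.output_linear = nn.Linear(hidden_dim, vocab_size)
            if tie_weights:
                # Sharing requires matching shapes: the embedding weight is
                # (vocab_size, embedding_dim) and the linear weight is
                # (vocab_size, hidden_dim).
                assert embedding_dim == hidden_dim, "tie-weights needs embedding_dim == hidden_dim"
                self.output_linear.weight = self.input_embedding.weight

        def forward(self, tokens):
            x = self.input_embedding(tokens)
            out, _ = self.rnn(x)
            return self.output_linear(out)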
@@ -513,23 +512,13 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))

-    if params.use_ddp_launch:
-        local_rank = get_local_rank()
-    else:
-        local_rank = rank
-
-    logging.warning(
-        f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
-    )
-
-    fix_random_seed(42)
-    if world_size > 1:
-        setup_dist(rank, world_size, params.master_port, params.use_ddp_launch)
-
-    setup_logger(
-        f"{params.exp_dir}/log/log-train", rank=rank, world_size=world_size
-    )
+    is_distributed = world_size > 1
+
+    fix_random_seed(params.seed)
+    if is_distributed:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")

     logging.info("Training started")
     logging.info(params)
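With the use_ddp_launch branch gone, the rank passed to run() doubles as the local GPU index, so get_local_rank() and the extra setup_dist argument are no longer needed. For orientation only, here is a rough sketch of what a setup_dist-style helper typically does in this single-node, mp.spawn-driven setup; it is an assumption for illustration, not icefall's actual icefall.dist.setup_dist, and the default port is made up:

    import os
    import torch
    import torch.distributed as dist

    def setup_dist_sketch(rank, world_size, master_port=None):
        """Initialize one process of a single-node DDP group (illustrative only)."""
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port or 12354)  # assumed default
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)  # rank == local GPU index on a single node

    def cleanup_dist_sketch():
        dist.destroy_process_group()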
@@ -540,9 +529,9 @@ def run(rank, world_size, args):
     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", local_rank)
-    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")

     logging.info("About to create model")
     model = RnnLmModel(
@@ -555,8 +544,8 @@ def run(rank, world_size, args):
     checkpoints = load_checkpoint_if_available(params=params, model=model)

     model.to(device)
-    if world_size > 1:
-        model = DDP(model, device_ids=[local_rank])
+    if is_distributed:
+        model = DDP(model, device_ids=[rank])

     model.device = device
@@ -572,20 +561,20 @@ def run(rank, world_size, args):
     logging.info(f"Loading LM training data from {params.lm_data}")
     train_dl = get_dataloader(
         filename=params.lm_data,
-        is_distributed=world_size > 1,
+        is_distributed=is_distributed,
         params=params,
     )

     logging.info(f"Loading LM validation data from {params.lm_data_valid}")
     valid_dl = get_dataloader(
         filename=params.lm_data_valid,
-        is_distributed=world_size > 1,
+        is_distributed=is_distributed,
         params=params,
     )

     # Note: No learning rate scheduler is used here
     for epoch in range(params.start_epoch, params.num_epochs):
-        if world_size > 1:
+        if is_distributed:
             train_dl.sampler.set_epoch(epoch)

         params.cur_epoch = epoch
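The train_dl.sampler.set_epoch(epoch) call kept under the new is_distributed guard reseeds the DistributedSampler so each epoch gets a different shuffle that is still consistent across ranks; without it every epoch replays the same order. A small self-contained illustration (the dataset and sizes are made up):

    from torch.utils.data import DistributedSampler

    # Pretend this process is rank 0 of a 2-process job over an 8-item dataset.
    dataset = list(range(8))
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

    for epoch in range(2):
        sampler.set_epoch(epoch)           # different, but rank-consistent, order each epoch
        print(epoch, list(iter(sampler)))  # the indices this rank would read this epoch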
@@ -609,7 +598,7 @@ def run(rank, world_size, args):
     logging.info("Done!")

-    if world_size > 1:
+    if is_distributed:
         torch.distributed.barrier()
         cleanup_dist()
@@ -619,17 +608,6 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

-    if args.use_ddp_launch:
-        # for torch.distributed.launch
-        rank = get_rank()
-        world_size = get_world_size()
-        print(f"rank: {rank}, world_size: {world_size}")
-
-        # The following is a hack as the default log level
-        # is warning
-        logging.info = logging.warning
-        run(rank=rank, world_size=world_size, args=args)
-        return
     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1:
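With the use_ddp_launch branch deleted, main() is left with only the --world-size path. The hunk is truncated here, but in this style of script the if world_size > 1: branch typically spawns one training process per GPU with torch.multiprocessing.spawn and otherwise calls run directly. The sketch below shows that pattern as an assumption about the code that follows, not a quote of it:

    import torch.multiprocessing as mp

    def main_sketch(run, args):
        world_size = args.world_size
        assert world_size >= 1
        if world_size > 1:
            # mp.spawn calls run(i, world_size, args) once per process,
            # passing the process index i as the rank.
            mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
        else:
            run(rank=0, world_size=1, args=args)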