Remove DDP

Erwan 2022-06-15 08:57:08 +02:00
parent 779589a2de
commit 25d540a758
2 changed files with 41 additions and 64 deletions


@@ -312,6 +312,5 @@ def get_dataloader(
collate_fn=collate_fn,
sampler=sampler,
shuffle=sampler is None,
num_workers=2,
)
return dataloader
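
As a side note on the hunk above: PyTorch's DataLoader rejects shuffle=True when an explicit sampler is supplied, which is why the code keeps `shuffle=sampler is None`. Below is a minimal, self-contained sketch of that pattern; the helper and toy dataset are hypothetical, not the icefall get_dataloader.

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


    def make_dataloader(dataset, batch_size, is_distributed):
        # Hypothetical helper: it only shows why `shuffle` must be False
        # whenever an explicit sampler is passed to the DataLoader.
        sampler = DistributedSampler(dataset, shuffle=True) if is_distributed else None
        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            shuffle=sampler is None,  # shuffle only when no sampler is given
        )


    if __name__ == "__main__":
        data = TensorDataset(torch.arange(100))
        dl = make_dataloader(data, batch_size=10, is_distributed=False)
        print(sum(1 for _ in dl))  # 10 batches
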


@@ -19,20 +19,14 @@
Usage:
./rnn_lm/train.py \
--start-epoch 0 \
--num-epochs 20 \
--batch-size 200 \
--world-size 2 \
--num-epochs 1 \
--use-fp16 0 \
--embedding-dim 800 \
--hidden-dim 200 \
--num-layers 2 \
--batch-size 400
If you want to use DDP training, e.g., on a single node with 4 GPUs,
use:
python -m torch.distributed.launch \
--use_env \
--nproc_per_node 4 \
./rnn_lm/train.py \
--use-ddp-launch true \
--start-epoch 0 \
--num-epochs 10 \
--batch-size 200
"""
import argparse
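
For context on the two invocations in the usage string above (not part of the diff): `torch.distributed.launch --use_env` exports RANK, WORLD_SIZE and LOCAL_RANK as environment variables, which is presumably what the removed get_rank / get_world_size / get_local_rank helpers read. A minimal sketch under that assumption:

    import os
    from typing import Tuple


    def env_rank_info() -> Tuple[int, int, int]:
        # Hypothetical helper, not the icefall.dist implementation: read the
        # launcher's environment variables and fall back to single-process
        # defaults when they are not set.
        rank = int(os.environ.get("RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        return rank, world_size, local_rank


    if __name__ == "__main__":
        print(env_rank_info())  # (0, 1, 0) when run without a launcher
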
@@ -46,29 +40,18 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from dataset import get_dataloader
from lhotse.utils import fix_random_seed
from rnn_lm.dataset import get_dataloader
from rnn_lm.model import RnnLmModel
from model import RnnLmModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import (
cleanup_dist,
get_local_rank,
get_rank,
get_world_size,
setup_dist,
)
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_env_info,
setup_logger,
str2bool,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
def get_parser():
@@ -186,6 +169,22 @@ def get_parser():
help="Number of RNN layers the model",
)
parser.add_argument(
"--tie-weights",
type=str2bool,
default=False,
help="""True share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
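
The new `--tie-weights` option refers to the common trick of sharing a single weight matrix between the input embedding and the final output projection, which requires the embedding and hidden dimensions to match. A generic illustration of the idea (a toy module, not the actual RnnLmModel):

    import torch
    import torch.nn as nn


    class TinyRnnLm(nn.Module):
        # Hypothetical model used only to illustrate weight tying.
        def __init__(self, vocab_size, embedding_dim, hidden_dim, tie_weights=False):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
            self.output = nn.Linear(hidden_dim, vocab_size)
            if tie_weights:
                # The embedding weight is (vocab_size, embedding_dim) and the
                # linear weight is (vocab_size, hidden_dim), so the two
                # dimensions must be equal for sharing to be possible.
                assert embedding_dim == hidden_dim, "tie-weights needs equal dims"
                self.output.weight = self.embedding.weight

        def forward(self, tokens):
            out, _ = self.rnn(self.embedding(tokens))
            return self.output(out)


    if __name__ == "__main__":
        lm = TinyRnnLm(vocab_size=500, embedding_dim=64, hidden_dim=64, tie_weights=True)
        logits = lm(torch.randint(0, 500, (2, 7)))
        print(logits.shape)  # torch.Size([2, 7, 500])
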
@@ -513,23 +512,13 @@ def run(rank, world_size, args):
"""
params = get_params()
params.update(vars(args))
is_distributed = world_size > 1
if params.use_ddp_launch:
local_rank = get_local_rank()
else:
local_rank = rank
fix_random_seed(params.seed)
if is_distributed:
setup_dist(rank, world_size, params.master_port)
logging.warning(
f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
)
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port, params.use_ddp_launch)
setup_logger(
f"{params.exp_dir}/log/log-train", rank=rank, world_size=world_size
)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
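
After the change, run() only supports processes spawned by the script itself, so setup_dist(rank, world_size, params.master_port) has to establish the process group from the arguments it is given. A simplified sketch of that kind of setup, as an assumption about typical behaviour rather than the real icefall.dist implementation (the default port below is only illustrative):

    import os

    import torch
    import torch.distributed as dist


    def setup_dist_sketch(rank, world_size, master_port=12354):
        # Simplified stand-in for setup_dist(): point every process at the
        # same rendezvous address, then join the process group.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", str(master_port))
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)


    def cleanup_dist_sketch():
        dist.destroy_process_group()
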
@@ -540,9 +529,9 @@ def run(rank, world_size, args):
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", local_rank)
device = torch.device("cuda", rank)
logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
logging.info(f"Device: {device}")
logging.info("About to create model")
model = RnnLmModel(
@@ -555,8 +544,8 @@ def run(rank, world_size, args):
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[local_rank])
if is_distributed:
model = DDP(model, device_ids=[rank])
model.device = device
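
The hunk above switches the DDP wrap from local_rank to rank, which is equivalent on a single node where the rank coincides with the local GPU index. A generic illustration of the usual order of operations (not the icefall code): move the model to the rank's device first, then wrap it:

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP


    def wrap_model(model, rank, world_size):
        # Illustrative pattern assuming one process per GPU, so that `rank`
        # doubles as the CUDA device index on a single node.
        device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu")
        model.to(device)
        if world_size > 1:
            # Requires an initialized process group (see the setup sketch above).
            model = DDP(model, device_ids=[rank] if device.type == "cuda" else None)
        return model
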
@@ -572,20 +561,20 @@ def run(rank, world_size, args):
logging.info(f"Loading LM training data from {params.lm_data}")
train_dl = get_dataloader(
filename=params.lm_data,
is_distributed=world_size > 1,
is_distributed=is_distributed,
params=params,
)
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
valid_dl = get_dataloader(
filename=params.lm_data_valid,
is_distributed=world_size > 1,
is_distributed=is_distributed,
params=params,
)
# Note: No learning rate scheduler is used here
for epoch in range(params.start_epoch, params.num_epochs):
if world_size > 1:
if is_distributed:
train_dl.sampler.set_epoch(epoch)
params.cur_epoch = epoch
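
The set_epoch() call in the loop above is what makes a DistributedSampler produce a different shuffle every epoch; without it, every epoch reuses the same permutation. A small self-contained illustration (generic PyTorch, not icefall code):

    import torch
    from torch.utils.data import DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(8))
    # Passing num_replicas and rank explicitly avoids needing a process group.
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

    orders = []
    for epoch in range(2):
        sampler.set_epoch(epoch)  # comment this out and both epochs match
        orders.append(list(sampler))
    print(orders[0] != orders[1])  # True (with very high probability)
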
@@ -609,7 +598,7 @@ def run(rank, world_size, args):
logging.info("Done!")
if world_size > 1:
if is_distributed:
torch.distributed.barrier()
cleanup_dist()
@@ -619,17 +608,6 @@ def main():
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
if args.use_ddp_launch:
# for torch.distributed.launch
rank = get_rank()
world_size = get_world_size()
print(f"rank: {rank}, world_size: {world_size}")
# The following is a hack as the default log level
# is warning
logging.info = logging.warning
run(rank=rank, world_size=world_size, args=args)
return
world_size = args.world_size
assert world_size >= 1
if world_size > 1: