mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
Remove DDP
This commit is contained in:
parent
779589a2de
commit
25d540a758
@ -312,6 +312,5 @@ def get_dataloader(
|
||||
collate_fn=collate_fn,
|
||||
sampler=sampler,
|
||||
shuffle=sampler is None,
|
||||
num_workers=2,
|
||||
)
|
||||
return dataloader
|
||||
|
@ -19,20 +19,14 @@
|
||||
Usage:
|
||||
./rnn_lm/train.py \
|
||||
--start-epoch 0 \
|
||||
--num-epochs 20 \
|
||||
--batch-size 200 \
|
||||
--world-size 2 \
|
||||
--num-epochs 1 \
|
||||
--use-fp16 0 \
|
||||
--embedding-dim 800 \
|
||||
--hidden-dim 200 \
|
||||
--num-layers 2\
|
||||
--batch-size 400
|
||||
|
||||
If you want to use DDP training, e.g., a single node with 4 GPUs,
|
||||
use:
|
||||
|
||||
python -m torch.distributed.launch \
|
||||
--use_env \
|
||||
--nproc_per_node 4 \
|
||||
./rnn_lm/train.py \
|
||||
--use-ddp-launch true \
|
||||
--start-epoch 0 \
|
||||
--num-epochs 10 \
|
||||
--batch-size 200
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@ -46,29 +40,18 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from dataset import get_dataloader
|
||||
from lhotse.utils import fix_random_seed
|
||||
from rnn_lm.dataset import get_dataloader
|
||||
from rnn_lm.model import RnnLmModel
|
||||
from model import RnnLmModel
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import (
|
||||
cleanup_dist,
|
||||
get_local_rank,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
setup_dist,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
MetricsTracker,
|
||||
get_env_info,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -186,6 +169,22 @@ def get_parser():
|
||||
help="Number of RNN layers the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tie-weights",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True share the weights between the input embedding layer and the
|
||||
last output linear layer
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="The seed for random generators intended for reproducibility",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -513,23 +512,13 @@ def run(rank, world_size, args):
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
is_distributed = world_size > 1
|
||||
|
||||
if params.use_ddp_launch:
|
||||
local_rank = get_local_rank()
|
||||
else:
|
||||
local_rank = rank
|
||||
fix_random_seed(params.seed)
|
||||
if is_distributed:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
logging.warning(
|
||||
f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
|
||||
)
|
||||
|
||||
fix_random_seed(42)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port, params.use_ddp_launch)
|
||||
|
||||
setup_logger(
|
||||
f"{params.exp_dir}/log/log-train", rank=rank, world_size=world_size
|
||||
)
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
logging.info(params)
|
||||
|
||||
@ -540,9 +529,9 @@ def run(rank, world_size, args):
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", local_rank)
|
||||
device = torch.device("cuda", rank)
|
||||
|
||||
logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
logging.info("About to create model")
|
||||
model = RnnLmModel(
|
||||
@ -555,8 +544,8 @@ def run(rank, world_size, args):
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
model = DDP(model, device_ids=[local_rank])
|
||||
if is_distributed:
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
model.device = device
|
||||
|
||||
@ -572,20 +561,20 @@ def run(rank, world_size, args):
|
||||
logging.info(f"Loading LM training data from {params.lm_data}")
|
||||
train_dl = get_dataloader(
|
||||
filename=params.lm_data,
|
||||
is_distributed=world_size > 1,
|
||||
is_distributed=is_distributed,
|
||||
params=params,
|
||||
)
|
||||
|
||||
logging.info(f"Loading LM validation data from {params.lm_data_valid}")
|
||||
valid_dl = get_dataloader(
|
||||
filename=params.lm_data_valid,
|
||||
is_distributed=world_size > 1,
|
||||
is_distributed=is_distributed,
|
||||
params=params,
|
||||
)
|
||||
|
||||
# Note: No learning rate scheduler is used here
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
if world_size > 1:
|
||||
if is_distributed:
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
||||
params.cur_epoch = epoch
|
||||
@ -609,7 +598,7 @@ def run(rank, world_size, args):
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
if is_distributed:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
@ -619,17 +608,6 @@ def main():
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
if args.use_ddp_launch:
|
||||
# for torch.distributed.lanunch
|
||||
rank = get_rank()
|
||||
world_size = get_world_size()
|
||||
print(f"rank: {rank}, world_size: {world_size}")
|
||||
# This following is a hack as the default log level
|
||||
# is warning
|
||||
logging.info = logging.warning
|
||||
run(rank=rank, world_size=world_size, args=args)
|
||||
return
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
|
Loading…
x
Reference in New Issue
Block a user