Mirror of https://github.com/k2-fsa/icefall.git
Remove DDP
parent 779589a2de
commit 25d540a758
@@ -312,6 +312,5 @@ def get_dataloader(
         collate_fn=collate_fn,
         sampler=sampler,
         shuffle=sampler is None,
-        num_workers=2,
     )
     return dataloader
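For context, a hedged sketch of the DataLoader construction this hunk leaves behind; the dataset, sampler, collate_fn and batch size below are placeholders, not the recipe's real objects. With num_workers dropped, PyTorch falls back to its default of 0, so batches are produced in the main process.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder inputs; the real get_dataloader() builds these from an LM data file.
dataset = TensorDataset(torch.arange(10))
sampler = None       # a DistributedSampler would be passed when training is distributed
collate_fn = None    # None selects the default collate function

dataloader = DataLoader(
    dataset,
    batch_size=2,
    collate_fn=collate_fn,
    sampler=sampler,
    shuffle=sampler is None,
    # num_workers is no longer set, so it defaults to 0 (loading in the main process)
)
for batch in dataloader:
    pass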
@@ -19,20 +19,14 @@
 Usage:
         ./rnn_lm/train.py \
             --start-epoch 0 \
-            --num-epochs 20 \
-            --batch-size 200 \
-
-If you want to use DDP training, e.g., a single node with 4 GPUs,
-use:
-
-  python -m torch.distributed.launch \
-    --use_env \
-    --nproc_per_node 4 \
-    ./rnn_lm/train.py \
-      --use-ddp-launch true \
-      --start-epoch 0 \
-      --num-epochs 10 \
-      --batch-size 200
+            --world-size 2 \
+            --num-epochs 1 \
+            --use-fp16 0 \
+            --embedding-dim 800 \
+            --hidden-dim 200 \
+            --num-layers 2 \
+            --batch-size 400
+
 """

 import argparse
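With the torch.distributed.launch recipe removed from the docstring, multi-GPU runs are driven by --world-size alone. A minimal, hedged sketch of the spawn-based pattern this implies (the real main() presumably forwards args and performs more setup):

import torch.multiprocessing as mp

def run(rank, world_size):
    # mp.spawn passes the process index as the first positional argument
    print(f"hello from rank {rank} of {world_size}")

if __name__ == "__main__":
    world_size = 2  # e.g. the value given via --world-size
    if world_size > 1:
        mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1)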
@@ -46,29 +40,18 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from dataset import get_dataloader
 from lhotse.utils import fix_random_seed
-from rnn_lm.dataset import get_dataloader
-from rnn_lm.model import RnnLmModel
+from model import RnnLmModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter

 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dist import (
-    cleanup_dist,
-    get_local_rank,
-    get_rank,
-    get_world_size,
-    setup_dist,
-)
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    get_env_info,
-    setup_logger,
-    str2bool,
-)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool


 def get_parser():
@@ -186,6 +169,22 @@ def get_parser():
         help="Number of RNN layers the model",
     )

+    parser.add_argument(
+        "--tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
     return parser

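The new --tie-weights flag refers to the standard trick of sharing one weight matrix between the input embedding and the output projection, which requires the embedding dimension to match the projection's input size. A hedged, stand-alone sketch (illustrative sizes, not the RnnLmModel code):

import torch.nn as nn

vocab_size, embedding_dim = 500, 800   # illustrative sizes only
embedding = nn.Embedding(vocab_size, embedding_dim)
projection = nn.Linear(embedding_dim, vocab_size, bias=False)

# Tying: both modules now share a single parameter tensor, so updates to one
# are reflected in the other and the parameter count of the pair is halved.
projection.weight = embedding.weight
assert projection.weight is embedding.weight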
@@ -513,23 +512,13 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
+    is_distributed = world_size > 1

-    if params.use_ddp_launch:
-        local_rank = get_local_rank()
-    else:
-        local_rank = rank
-
-    logging.warning(
-        f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}"
-    )
-
-    fix_random_seed(42)
-    if world_size > 1:
-        setup_dist(rank, world_size, params.master_port, params.use_ddp_launch)
+    fix_random_seed(params.seed)
+    if is_distributed:
+        setup_dist(rank, world_size, params.master_port)

-    setup_logger(
-        f"{params.exp_dir}/log/log-train", rank=rank, world_size=world_size
-    )
+    setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
     logging.info(params)

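The hard-coded fix_random_seed(42) becomes fix_random_seed(params.seed), so reproducibility is now controlled by the new --seed option. A tiny hedged illustration of what seeding buys; torch.manual_seed stands in here for lhotse's fix_random_seed used in the diff:

import torch

torch.manual_seed(42)
a = torch.randn(3)
torch.manual_seed(42)
b = torch.randn(3)
assert torch.equal(a, b)  # identical seed, identical draws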
@@ -540,9 +529,9 @@ def run(rank, world_size, args):

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", local_rank)
+        device = torch.device("cuda", rank)

-    logging.info(f"Device: {device}, rank: {rank}, local_rank: {local_rank}")
+    logging.info(f"Device: {device}")

     logging.info("About to create model")
     model = RnnLmModel(
@@ -555,8 +544,8 @@ def run(rank, world_size, args):
     checkpoints = load_checkpoint_if_available(params=params, model=model)

     model.to(device)
-    if world_size > 1:
-        model = DDP(model, device_ids=[local_rank])
+    if is_distributed:
+        model = DDP(model, device_ids=[rank])

     model.device = device

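For reference, a self-contained, hedged sketch of the DDP wrapping kept here, run as a single CPU process with the gloo backend; with GPUs one would pass device_ids=[rank] as in the diff, and RnnLmModel is replaced by a placeholder layer:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:12355", rank=0, world_size=1
)
model = DDP(torch.nn.Linear(4, 4))       # stand-in for RnnLmModel
loss = model(torch.randn(2, 4)).sum()
loss.backward()                          # DDP all-reduces gradients across ranks
dist.destroy_process_group()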
@@ -572,20 +561,20 @@ def run(rank, world_size, args):
     logging.info(f"Loading LM training data from {params.lm_data}")
     train_dl = get_dataloader(
         filename=params.lm_data,
-        is_distributed=world_size > 1,
+        is_distributed=is_distributed,
         params=params,
     )

     logging.info(f"Loading LM validation data from {params.lm_data_valid}")
     valid_dl = get_dataloader(
         filename=params.lm_data_valid,
-        is_distributed=world_size > 1,
+        is_distributed=is_distributed,
         params=params,
     )

     # Note: No learning rate scheduler is used here
     for epoch in range(params.start_epoch, params.num_epochs):
-        if world_size > 1:
+        if is_distributed:
             train_dl.sampler.set_epoch(epoch)

         params.cur_epoch = epoch
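The set_epoch call that survives this hunk is what reshuffles data each epoch under DDP: DistributedSampler derives its permutation from its epoch counter, so skipping set_epoch replays the same order every epoch. A hedged toy demonstration (explicit rank and num_replicas, so no process group is needed):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
epoch0 = list(sampler)
sampler.set_epoch(1)
epoch1 = list(sampler)
print(epoch0, epoch1)  # typically different permutations of this rank's indices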
@@ -609,7 +598,7 @@ def run(rank, world_size, args):

     logging.info("Done!")

-    if world_size > 1:
+    if is_distributed:
         torch.distributed.barrier()
         cleanup_dist()

@@ -619,17 +608,6 @@ def main():
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

-    if args.use_ddp_launch:
-        # for torch.distributed.lanunch
-        rank = get_rank()
-        world_size = get_world_size()
-        print(f"rank: {rank}, world_size: {world_size}")
-        # This following is a hack as the default log level
-        # is warning
-        logging.info = logging.warning
-        run(rank=rank, world_size=world_size, args=args)
-        return
-
     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1:
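The deleted branch read the rank and world size that python -m torch.distributed.launch --use_env exports through environment variables. A hedged sketch of that mechanism, for anyone migrating old launch scripts (icefall's get_rank/get_world_size helpers are not reproduced here):

import os

# torch.distributed.launch --use_env (and torchrun) export these variables.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
print(f"rank: {rank}, world_size: {world_size}, local_rank: {local_rank}")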