Start training LM for LibriSpeech.

Fangjun Kuang 2021-11-23 11:58:25 +08:00
parent 3c65ee11f4
commit 774f6643cd
3 changed files with 51 additions and 14 deletions

View File

@@ -42,8 +42,6 @@ dl_dir=$PWD/download
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
   5000
-  2000
-  1000
   500
 )
@@ -288,7 +286,6 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
       cat $f | cut -d " " -f 2-
     done > $out_dir/test.txt
   fi
-  exit 0

   lang_dir=data/lang_bpe_${vocab_size}
   ./local/prepare_lm_training_data.py \
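
With the early exit 0 removed, stage 12 now continues past writing $out_dir/test.txt and runs ./local/prepare_lm_training_data.py for each remaining vocab size (5000 and 500). Conceptually, that step encodes the LM text with the BPE model of the matching lang dir. A rough, illustrative sketch of that idea, assuming a model at data/lang_bpe_500/bpe.model; the script's real arguments and output format may differ:

# Illustrative sketch only, not the actual ./local/prepare_lm_training_data.py.
# It shows the kind of BPE encoding the LM data preparation performs; the
# paths below are assumptions.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # lang_dir=data/lang_bpe_${vocab_size}

num_sentences = 0
with open("test.txt") as f:  # the $out_dir/test.txt written above
    for line in f:
        token_ids = sp.encode(line.strip(), out_type=int)  # list of BPE IDs
        num_sentences += 1

print(f"encoded {num_sentences} sentences as BPE token IDs")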

View File

@@ -21,14 +21,46 @@ import torch
 from torch import distributed as dist


-def setup_dist(rank, world_size, master_port=None):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = (
-        "12354" if master_port is None else str(master_port)
-    )
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
+def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
+    """
+    rank and world_size are used only if use_ddp_launch is False.
+    """
+    if "MASTER_ADDR" not in os.environ:
+        os.environ["MASTER_ADDR"] = "localhost"
+
+    if "MASTER_PORT" not in os.environ:
+        os.environ["MASTER_PORT"] = (
+            "12354" if master_port is None else str(master_port)
+        )
+
+    if use_ddp_launch is False:
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+    else:
+        dist.init_process_group("nccl")


 def cleanup_dist():
     dist.destroy_process_group()
+
+
+def get_world_size():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"])
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size()
+    else:
+        return 1
+
+
+def get_rank():
+    if "RANK" in os.environ:
+        return int(os.environ["RANK"])
+    elif dist.is_available() and dist.is_initialized():
+        return dist.get_rank()
+    else:
+        return 0
+
+
+def get_local_rank():
+    return int(os.environ.get("LOCAL_RANK", 0))
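
The new helpers support both launch modes: with use_ddp_launch=True (e.g. under torch.distributed.launch or torchrun), RANK, WORLD_SIZE and LOCAL_RANK are read from the environment set by the launcher; otherwise setup_dist keeps the old explicit rank/world_size behaviour. A minimal usage sketch; the surrounding training-script wiring is an assumption, only setup_dist, cleanup_dist and the get_* helpers come from this diff:

# Hypothetical caller of the helpers above (not part of this commit).
import torch

from icefall.dist import (
    cleanup_dist,
    get_local_rank,
    get_rank,
    get_world_size,
    setup_dist,
)


def main():
    # Launched as: torchrun --nproc_per_node=4 train.py
    # RANK/WORLD_SIZE/LOCAL_RANK are already in the environment, so the
    # rank/world_size arguments are ignored here.
    setup_dist(rank=None, world_size=None, use_ddp_launch=True)

    rank = get_rank()
    world_size = get_world_size()
    device = torch.device("cuda", get_local_rank())
    torch.cuda.set_device(device)

    # ... build the model, wrap it with DistributedDataParallel, train ...

    cleanup_dist()


if __name__ == "__main__":
    main()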

View File

@@ -91,7 +91,11 @@ def str2bool(v):


 def setup_logger(
-    log_filename: Pathlike, log_level: str = "info", use_console: bool = True
+    log_filename: Pathlike,
+    log_level: str = "info",
+    rank: int = 0,
+    world_size: int = 1,
+    use_console: bool = True,
 ) -> None:
     """Setup log level.

@@ -101,12 +105,16 @@ def setup_logger(
       log_level:
         The log level to use, e.g., "debug", "info", "warning", "error",
         "critical"
+      rank:
+        Rank of this node in DDP training.
+      world_size:
+        Number of nodes in DDP training.
+      use_console:
+        True to also print logs to console.
     """
     now = datetime.now()
     date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
-    if dist.is_available() and dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
+    if world_size > 1:
         formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
         log_filename = f"{log_filename}-{date_time}-{rank}"
     else:
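
setup_logger no longer queries torch.distributed itself; the caller passes rank and world_size explicitly, so logging can be set up even before the process group exists. A sketch of how a training script might combine it with the new dist helpers; the log path and the use_console policy are assumptions, not part of the diff:

# Hypothetical caller (not part of this commit); setup_logger lives in
# icefall/utils.py, get_rank/get_world_size in icefall/dist.py above.
from icefall.dist import get_rank, get_world_size
from icefall.utils import setup_logger

rank = get_rank()
world_size = get_world_size()

# With world_size > 1 each rank writes its own file, e.g.
# exp/log/log-train-<date>-<rank>, and log lines carry (rank/world_size).
setup_logger(
    log_filename="exp/log/log-train",  # assumed path
    log_level="info",
    rank=rank,
    world_size=world_size,
    use_console=(rank == 0),  # assumed policy: only rank 0 logs to console
)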