mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Start training LM for LibriSpeech.
This commit is contained in:
parent
3c65ee11f4
commit
774f6643cd
@ -42,8 +42,6 @@ dl_dir=$PWD/download
|
||||
# data/lang_bpe_yyy if the array contains xxx, yyy
|
||||
vocab_sizes=(
|
||||
5000
|
||||
2000
|
||||
1000
|
||||
500
|
||||
)
|
||||
|
||||
@ -288,7 +286,6 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
cat $f | cut -d " " -f 2-
|
||||
done > $out_dir/test.txt
|
||||
fi
|
||||
exit 0
|
||||
|
||||
lang_dir=data/lang_bpe_${vocab_size}
|
||||
./local/prepare_lm_training_data.py \
|
||||
|
@ -21,14 +21,46 @@ import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
|
||||
def setup_dist(rank, world_size, master_port=None):
|
||||
def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
|
||||
"""
|
||||
rank and world_size are used only if use_ddp_launch is False.
|
||||
"""
|
||||
if "MASTER_ADDR" not in os.environ:
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
|
||||
if "MASTER_PORT" not in os.environ:
|
||||
os.environ["MASTER_PORT"] = (
|
||||
"12354" if master_port is None else str(master_port)
|
||||
)
|
||||
|
||||
if use_ddp_launch is False:
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
else:
|
||||
dist.init_process_group("nccl")
|
||||
|
||||
|
||||
def cleanup_dist():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if "WORLD_SIZE" in os.environ:
|
||||
return int(os.environ["WORLD_SIZE"])
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def get_rank():
|
||||
if "RANK" in os.environ:
|
||||
return int(os.environ["RANK"])
|
||||
elif dist.is_available() and dist.is_initialized():
|
||||
return dist.rank()
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def get_local_rank():
|
||||
return int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
@ -91,7 +91,11 @@ def str2bool(v):
|
||||
|
||||
|
||||
def setup_logger(
|
||||
log_filename: Pathlike, log_level: str = "info", use_console: bool = True
|
||||
log_filename: Pathlike,
|
||||
log_level: str = "info",
|
||||
rank: int = 0,
|
||||
world_size: int = 1,
|
||||
use_console: bool = True,
|
||||
) -> None:
|
||||
"""Setup log level.
|
||||
|
||||
@ -101,12 +105,16 @@ def setup_logger(
|
||||
log_level:
|
||||
The log level to use, e.g., "debug", "info", "warning", "error",
|
||||
"critical"
|
||||
rank:
|
||||
Rank of this node in DDP training.
|
||||
world_size:
|
||||
Number of nodes in DDP training.
|
||||
use_console:
|
||||
True to also print logs to console.
|
||||
"""
|
||||
now = datetime.now()
|
||||
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
if world_size > 1:
|
||||
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
|
||||
log_filename = f"{log_filename}-{date_time}-{rank}"
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user