mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Start training LM for LibriSpeech.
This commit is contained in:
parent
3c65ee11f4
commit
774f6643cd
@ -42,8 +42,6 @@ dl_dir=$PWD/download
|
|||||||
# data/lang_bpe_yyy if the array contains xxx, yyy
|
# data/lang_bpe_yyy if the array contains xxx, yyy
|
||||||
vocab_sizes=(
|
vocab_sizes=(
|
||||||
5000
|
5000
|
||||||
2000
|
|
||||||
1000
|
|
||||||
500
|
500
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -288,7 +286,6 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
|||||||
cat $f | cut -d " " -f 2-
|
cat $f | cut -d " " -f 2-
|
||||||
done > $out_dir/test.txt
|
done > $out_dir/test.txt
|
||||||
fi
|
fi
|
||||||
exit 0
|
|
||||||
|
|
||||||
lang_dir=data/lang_bpe_${vocab_size}
|
lang_dir=data/lang_bpe_${vocab_size}
|
||||||
./local/prepare_lm_training_data.py \
|
./local/prepare_lm_training_data.py \
|
||||||
|
@ -21,14 +21,46 @@ import torch
|
|||||||
from torch import distributed as dist
|
from torch import distributed as dist
|
||||||
|
|
||||||
|
|
||||||
def setup_dist(rank, world_size, master_port=None):
|
def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False):
|
||||||
os.environ["MASTER_ADDR"] = "localhost"
|
"""
|
||||||
os.environ["MASTER_PORT"] = (
|
rank and world_size are used only if use_ddp_launch is False.
|
||||||
"12354" if master_port is None else str(master_port)
|
"""
|
||||||
)
|
if "MASTER_ADDR" not in os.environ:
|
||||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
torch.cuda.set_device(rank)
|
|
||||||
|
if "MASTER_PORT" not in os.environ:
|
||||||
|
os.environ["MASTER_PORT"] = (
|
||||||
|
"12354" if master_port is None else str(master_port)
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_ddp_launch is False:
|
||||||
|
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
else:
|
||||||
|
dist.init_process_group("nccl")
|
||||||
|
|
||||||
|
|
||||||
def cleanup_dist():
|
def cleanup_dist():
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
    """Return the number of processes taking part in training.

    Lookup order:
      1. The ``WORLD_SIZE`` environment variable, converted to ``int``.
      2. ``dist.get_world_size()`` when torch.distributed is available
         and a process group has been initialized.
      3. ``1`` otherwise.
    """
    env_value = os.environ.get("WORLD_SIZE")
    if env_value is not None:
        return int(env_value)

    # No environment override: ask torch.distributed, if it is usable.
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
    """Return the rank of the current process.

    Lookup order:
      1. The ``RANK`` environment variable, converted to ``int``.
      2. ``dist.get_rank()`` when torch.distributed is available and a
         process group has been initialized.
      3. ``0`` otherwise.
    """
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    elif dist.is_available() and dist.is_initialized():
        # torch.distributed exposes get_rank(); there is no dist.rank().
        return dist.get_rank()
    else:
        # A lone, non-distributed process is rank 0 (ranks are 0-based),
        # consistent with the default of 0 used by get_local_rank().
        return 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_rank():
    """Return the ``LOCAL_RANK`` environment variable as an ``int``.

    Falls back to ``0`` when the variable is not set.
    """
    local_rank = os.environ.get("LOCAL_RANK", 0)
    return int(local_rank)
|
||||||
|
@ -91,7 +91,11 @@ def str2bool(v):
|
|||||||
|
|
||||||
|
|
||||||
def setup_logger(
|
def setup_logger(
|
||||||
log_filename: Pathlike, log_level: str = "info", use_console: bool = True
|
log_filename: Pathlike,
|
||||||
|
log_level: str = "info",
|
||||||
|
rank: int = 0,
|
||||||
|
world_size: int = 1,
|
||||||
|
use_console: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Setup log level.
|
"""Setup log level.
|
||||||
|
|
||||||
@ -101,12 +105,16 @@ def setup_logger(
|
|||||||
log_level:
|
log_level:
|
||||||
The log level to use, e.g., "debug", "info", "warning", "error",
|
The log level to use, e.g., "debug", "info", "warning", "error",
|
||||||
"critical"
|
"critical"
|
||||||
|
rank:
|
||||||
|
Rank of this node in DDP training.
|
||||||
|
world_size:
|
||||||
|
Number of nodes in DDP training.
|
||||||
|
use_console:
|
||||||
|
True to also print logs to console.
|
||||||
"""
|
"""
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
|
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
|
||||||
if dist.is_available() and dist.is_initialized():
|
if world_size > 1:
|
||||||
world_size = dist.get_world_size()
|
|
||||||
rank = dist.get_rank()
|
|
||||||
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
|
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
|
||||||
log_filename = f"{log_filename}-{date_time}-{rank}"
|
log_filename = f"{log_filename}-{date_time}-{rank}"
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user