add tqdm to show a progress bar during training

songmeixu 2021-09-20 20:12:27 +08:00
parent a80e58e15d
commit 236b305698
2 changed files with 147 additions and 112 deletions
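
In short, the change wraps the training dataloader in a tqdm progress bar that counts lhotse cuts rather than batches and is drawn only on the main rank. Below is a minimal, self-contained sketch of that pattern, not the committed code: the fake dataloader, the num_cuts argument, and the placeholder loss are assumptions for illustration.

from tqdm import tqdm

def iterate_with_progress(train_dl, num_cuts, is_main):
    # Progress is measured in cuts (lhotse's unit of data), so the bar total
    # is the number of cuts in the epoch, not the number of batches.
    with tqdm(
        total=num_cuts,
        unit="cut",
        leave=False,
        dynamic_ncols=True,
        disable=not is_main,  # only the main rank draws a bar under DDP
    ) as t:
        for batch in train_dl:
            batch_size = len(batch["supervisions"]["text"])  # cuts in this batch
            loss = 0.0  # placeholder for the real forward/backward step
            t.set_postfix(total_avg_loss=loss)
            t.update(batch_size)  # advance the bar by the cuts just consumed

# Quick demonstration with a fake dataloader of 5 batches x 4 cuts:
if __name__ == "__main__":
    fake_dl = [{"supervisions": {"text": ["utt"] * 4}} for _ in range(5)]
    iterate_with_progress(fake_dl, num_cuts=20, is_main=True)

Counting cuts rather than batches presumably keeps the bar meaningful when the lhotse sampler produces batches of varying size.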

View File

@@ -28,6 +28,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from tqdm import tqdm
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
@@ -46,6 +47,7 @@ from icefall.utils import (
     encode_supervisions,
     setup_logger,
     str2bool,
+    is_main_process,
 )
@@ -478,7 +480,17 @@ def train_one_epoch(
     tot_frames = 0.0  # sum of frames over all batches
     params.tot_loss = 0.0
     params.tot_frames = 0.0
-    for batch_idx, batch in enumerate(train_dl):
+    with tqdm(
+        train_dl,
+        initial=0,
+        total=train_dl.sampler.num_cuts,
+        leave=False,
+        unit="cut",  # "cut" is the unit of data used by lhotse
+        dynamic_ncols=True,
+        disable=not is_main_process(),
+    ) as t:
+        for batch_idx, batch in enumerate(t):
             params.batch_idx_train += 1
             batch_size = len(batch["supervisions"]["text"])
@@ -514,12 +526,18 @@
             tot_avg_ctc_loss = tot_ctc_loss / tot_frames
             tot_avg_att_loss = tot_att_loss / tot_frames
+            t.set_postfix(total_avg_loss=tot_avg_loss)
+            t.update(batch_size)
             if batch_idx % params.log_interval == 0:
                 logging.info(
                     f"Epoch {params.cur_epoch}, batch {batch_idx}, "
-                    f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
-                    f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
-                    f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
+                    f"batch avg ctc loss \
+{ctc_loss_cpu/params.train_frames:.4f}, "
+                    f"batch avg att loss \
+{att_loss_cpu/params.train_frames:.4f}, "
+                    f"batch avg loss \
+{loss_cpu/params.train_frames:.4f}, "
                     f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
                     f"total avg att loss: {tot_avg_att_loss:.4f}, "
                     f"total avg loss: {tot_avg_loss:.4f}, "
@@ -692,7 +710,7 @@ def run(rank, world_size, args):
            )
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

-        if rank == 0:
+        if is_main_process():
            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))

        params.cur_epoch = epoch

View File

@@ -411,3 +411,20 @@ def write_error_stats(
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)

     return float(tot_err_rate)
+
+
+def is_main_process():
+    """Check whether the current process is the main process.
+
+    - In DistributedDataParallel (DDP) mode:
+      The current process is the main one if its rank is 0; the rank is read
+      from os.environ["RANK"], which is set by the DDP launcher.
+    - In single-process (non-DDP) mode:
+      os.environ["RANK"] is not set, so the process is treated as the main one.
+    """
+    if os.environ.get("RANK") is None:
+        return True
+    elif os.environ.get("RANK") == "0":
+        return True
+    else:
+        return False
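
For reference, here is a condensed, hedged restatement of the new helper and how it behaves with and without the RANK environment variable that DDP launchers such as torchrun export; the small test script itself is illustrative and not part of the commit.

import os

def is_main_process():
    # Condensed equivalent logic: main unless RANK is set to something other than "0".
    rank = os.environ.get("RANK")
    return rank is None or rank == "0"

if __name__ == "__main__":
    os.environ.pop("RANK", None)
    assert is_main_process()       # single-process run: no RANK, treated as main
    os.environ["RANK"] = "0"
    assert is_main_process()       # DDP rank 0 is the main process
    os.environ["RANK"] = "1"
    assert not is_main_process()   # any other DDP rank is not the main process
    print("is_main_process behaves as expected")

This mirrors the disable=not is_main_process() guard in the training script, which keeps ranks other than 0 from drawing duplicate progress bars.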