add tqdm to show a progress bar during training

songmeixu 2021-09-20 20:12:27 +08:00
parent a80e58e15d
commit 236b305698
2 changed files with 147 additions and 112 deletions
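
As a quick orientation, here is a minimal, self-contained sketch of the tqdm pattern this commit introduces: the progress total is measured in lhotse "cuts" rather than batches, so the bar is advanced manually by each batch's cut count instead of once per iteration. All names below are illustrative stand-ins, not icefall's.

    from tqdm import tqdm

    num_cuts = 1000                   # stands in for train_dl.sampler.num_cuts
    batches = [list(range(8))] * 125  # stands in for the training dataloader

    with tqdm(total=num_cuts, unit="cut", dynamic_ncols=True, leave=False) as t:
        for batch in batches:
            # ... forward/backward pass would go here ...
            t.set_postfix(total_avg_loss=0.123)  # display a running metric
            t.update(len(batch))  # advance by the number of cuts in this batch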


@@ -28,6 +28,7 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from tqdm import tqdm
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
@@ -46,6 +47,7 @@ from icefall.utils import (
    encode_supervisions,
    is_main_process,
    setup_logger,
    str2bool,
)
@@ -478,7 +480,17 @@ def train_one_epoch(
    tot_frames = 0.0  # sum of frames over all batches
    params.tot_loss = 0.0
    params.tot_frames = 0.0
    for batch_idx, batch in enumerate(train_dl):
    # Pass only `total` to tqdm instead of wrapping train_dl: the bar is
    # advanced manually with t.update(batch_size) below, and wrapping the
    # iterable as well would make tqdm count every batch twice.
    with tqdm(
        total=train_dl.sampler.num_cuts,
        leave=False,
        unit="cut",  # a "cut" is lhotse's unit of audio data
        dynamic_ncols=True,
        disable=not is_main_process(),  # show the bar only in the main process
    ) as t:
        for batch_idx, batch in enumerate(train_dl):
            params.batch_idx_train += 1
            batch_size = len(batch["supervisions"]["text"])
@@ -514,12 +526,18 @@ def train_one_epoch(
            tot_avg_ctc_loss = tot_ctc_loss / tot_frames
            tot_avg_att_loss = tot_att_loss / tot_frames
            t.set_postfix(total_avg_loss=tot_avg_loss)
            t.update(batch_size)  # advance the bar by this batch's cut count
            if batch_idx % params.log_interval == 0:
                logging.info(
                    f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
                    f"batch avg ctc loss "
                    f"{ctc_loss_cpu/params.train_frames:.4f}, "
                    f"batch avg att loss "
                    f"{att_loss_cpu/params.train_frames:.4f}, "
                    f"batch avg loss "
                    f"{loss_cpu/params.train_frames:.4f}, "
                    f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
                    f"total avg att loss: {tot_avg_att_loss:.4f}, "
                    f"total avg loss: {tot_avg_loss:.4f}, "
@@ -692,7 +710,7 @@ def run(rank, world_size, args):
            )
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
        if rank == 0:
        if is_main_process():
            logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
        params.cur_epoch = epoch


@@ -411,3 +411,20 @@ def write_error_stats(
        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)


def is_main_process() -> bool:
    """Check whether the current process is the main process.

    - In DistributedDataParallel (DDP) mode:
      The current process is the main process if its rank is 0; the rank
      is read from os.environ["RANK"], which DDP has already set.
    - In standard (non-DDP) mode and others:
      os.environ["RANK"] is not set, so the current process is treated
      as the main process.
    """
    return os.environ.get("RANK", "0") == "0"
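
For illustration, the intended behavior of is_main_process() under the usual setups, assuming the DDP launcher sets the RANK environment variable:

    import os

    os.environ.pop("RANK", None)   # standard single-process run: no RANK
    assert is_main_process()

    os.environ["RANK"] = "0"       # DDP rank 0 is the main process
    assert is_main_process()

    os.environ["RANK"] = "3"       # any non-zero DDP rank is not
    assert not is_main_process()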