mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-15 20:22:42 +00:00
Add tqdm to display a progress bar during training
This commit is contained in:
parent
a80e58e15d
commit
236b305698
@ -28,6 +28,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
@ -46,6 +47,7 @@ from icefall.utils import (
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
is_main_process,
|
||||
)
|
||||
|
||||
|
||||
@ -478,7 +480,17 @@ def train_one_epoch(
|
||||
tot_frames = 0.0 # sum of frames over all batches
|
||||
params.tot_loss = 0.0
|
||||
params.tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
|
||||
with tqdm(
|
||||
train_dl,
|
||||
initial=0,
|
||||
total=train_dl.sampler.num_cuts,
|
||||
leave=False,
|
||||
unit="cut", # a data fragment concept comes from "lhotse"
|
||||
dynamic_ncols=True,
|
||||
disable=not is_main_process(),
|
||||
) as t:
|
||||
for batch_idx, batch in enumerate(t):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
@ -514,12 +526,18 @@ def train_one_epoch(
|
||||
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
|
||||
tot_avg_att_loss = tot_att_loss / tot_frames
|
||||
|
||||
t.set_postfix(total_avg_loss=tot_avg_loss)
|
||||
t.update(batch_size)
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg ctc loss \
|
||||
{ctc_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg att loss \
|
||||
{att_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg loss \
|
||||
{loss_cpu/params.train_frames:.4f}, "
|
||||
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
|
||||
f"total avg att loss: {tot_avg_att_loss:.4f}, "
|
||||
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||
@ -692,7 +710,7 @@ def run(rank, world_size, args):
|
||||
)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
if is_main_process():
|
||||
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
@ -411,3 +411,20 @@ def write_error_stats(
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
return float(tot_err_rate)
|
||||
|
||||
|
||||
def is_main_process() -> bool:
    """Return True if the current process is the main (rank 0) process.

    The rank is determined as follows:

    - If ``torch.distributed`` has already been imported by the program
      and a default process group is initialised, the rank reported by
      ``torch.distributed.get_rank()`` is used. This also covers launch
      styles that hand the rank to the worker as a function argument and
      do not necessarily export ``RANK`` into the environment.
    - Otherwise the ``RANK`` environment variable is consulted. If it is
      not set at all, the process is treated as a standard
      (non-distributed) run and is therefore the main process; if it is
      set, the process is main only when its value is exactly ``"0"``.

    Returns:
      True for the main process, False for every other process.
    """
    # Local import: ``sys`` is needed only for the module lookup below.
    import sys

    # Look the module up instead of importing it: if ``torch.distributed``
    # was never imported, no process group can have been initialised, and
    # this helper stays cheap and usable without torch.
    dist = sys.modules.get("torch.distributed")
    if dist is not None and dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0

    # An unset RANK means "not launched in distributed mode" and is
    # equivalent to rank 0; any value other than "0" is a non-main worker.
    return os.environ.get("RANK", "0") == "0"
|
||||
|
Loading…
x
Reference in New Issue
Block a user