From 236b305698fbc0c88d723ed5ba297912f5ce916e Mon Sep 17 00:00:00 2001 From: songmeixu Date: Mon, 20 Sep 2021 20:12:27 +0800 Subject: [PATCH] add tqdm for getting a progress bar of the training --- egs/librispeech/ASR/conformer_ctc/train.py | 242 +++++++++++---------- icefall/utils.py | 17 ++ 2 files changed, 147 insertions(+), 112 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 80b2d924a..4bfcbfed3 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -28,6 +28,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from tqdm import tqdm from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from lhotse.utils import fix_random_seed @@ -46,6 +47,7 @@ from icefall.utils import ( encode_supervisions, setup_logger, str2bool, + is_main_process, ) @@ -478,126 +480,142 @@ def train_one_epoch( tot_frames = 0.0 # sum of frames over all batches params.tot_loss = 0.0 params.tot_frames = 0.0 - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - loss, ctc_loss, att_loss = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) + with tqdm( + train_dl, + initial=0, + total=train_dl.sampler.num_cuts, + leave=False, + unit="cut", # a data fragment concept comes from "lhotse" + dynamic_ncols=True, + disable=not is_main_process(), + ) as t: + for batch_idx, batch in enumerate(t): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - loss_cpu = loss.detach().cpu().item() - ctc_loss_cpu = ctc_loss.detach().cpu().item() - att_loss_cpu = att_loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_ctc_loss += ctc_loss_cpu - tot_att_loss += att_loss_cpu - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - - tot_avg_loss = tot_loss / tot_frames - tot_avg_ctc_loss = tot_ctc_loss / tot_frames - tot_avg_att_loss = tot_att_loss / tot_frames - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, " - f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, " - f"total avg att loss: {tot_avg_att_loss:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/current_ctc_loss", - ctc_loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/current_att_loss", - att_loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_ctc_loss", - tot_avg_ctc_loss, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_att_loss", - tot_avg_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - - tot_frames = 0.0 # sum of frames over all batches - - if batch_idx > 0 and batch_idx % 
params.valid_interval == 0: - compute_validation_loss( + loss, ctc_loss, att_loss = compute_loss( params=params, model=model, + batch=batch, graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, + is_training=True, ) - model.train() - logging.info( - f"Epoch {params.cur_epoch}, " - f"valid ctc loss {params.valid_ctc_loss:.4f}," - f"valid att loss {params.valid_att_loss:.4f}," - f"valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) - if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_ctc_loss", - params.valid_ctc_loss, - params.batch_idx_train, + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + loss_cpu = loss.detach().cpu().item() + ctc_loss_cpu = ctc_loss.detach().cpu().item() + att_loss_cpu = att_loss.detach().cpu().item() + + tot_frames += params.train_frames + tot_loss += loss_cpu + tot_ctc_loss += ctc_loss_cpu + tot_att_loss += att_loss_cpu + + params.tot_frames += params.train_frames + params.tot_loss += loss_cpu + + tot_avg_loss = tot_loss / tot_frames + tot_avg_ctc_loss = tot_ctc_loss / tot_frames + tot_avg_att_loss = tot_att_loss / tot_frames + + t.set_postfix(total_avg_loss=tot_avg_loss) + t.update(batch_size) + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"batch avg ctc loss \ + {ctc_loss_cpu/params.train_frames:.4f}, " + f"batch avg att loss \ + {att_loss_cpu/params.train_frames:.4f}, " + f"batch avg loss \ + {loss_cpu/params.train_frames:.4f}, " + f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, " + f"total avg att loss: {tot_avg_att_loss:.4f}, " + f"total avg loss: {tot_avg_loss:.4f}, " + f"batch size: {batch_size}" ) - tb_writer.add_scalar( - "train/valid_att_loss", - 
params.valid_att_loss, - params.batch_idx_train, + + if tb_writer is not None: + tb_writer.add_scalar( + "train/current_ctc_loss", + ctc_loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/current_att_loss", + att_loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/current_loss", + loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/tot_avg_ctc_loss", + tot_avg_ctc_loss, + params.batch_idx_train, + ) + + tb_writer.add_scalar( + "train/tot_avg_att_loss", + tot_avg_att_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/tot_avg_loss", + tot_avg_loss, + params.batch_idx_train, + ) + if batch_idx > 0 and batch_idx % params.reset_interval == 0: + tot_loss = 0.0 # sum of losses over all batches + tot_ctc_loss = 0.0 + tot_att_loss = 0.0 + + tot_frames = 0.0 # sum of frames over all batches + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, ) - tb_writer.add_scalar( - "train/valid_loss", - params.valid_loss, - params.batch_idx_train, + model.train() + logging.info( + f"Epoch {params.cur_epoch}, " + f"valid ctc loss {params.valid_ctc_loss:.4f}," + f"valid att loss {params.valid_att_loss:.4f}," + f"valid loss {params.valid_loss:.4f}," + f" best valid loss: {params.best_valid_loss:.4f} " + f"best valid epoch: {params.best_valid_epoch}" ) + if tb_writer is not None: + tb_writer.add_scalar( + "train/valid_ctc_loss", + params.valid_ctc_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_att_loss", + params.valid_att_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_loss", + params.valid_loss, + params.batch_idx_train, + ) params.train_loss = params.tot_loss / params.tot_frames @@ -692,7 +710,7 @@ def run(rank, world_size, args): ) 
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - if rank == 0: + if is_main_process(): logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) params.cur_epoch = epoch diff --git a/icefall/utils.py b/icefall/utils.py index 2324201c3..6960724d7 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -411,3 +411,20 @@ def write_error_stats( print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) return float(tot_err_rate) + + +def is_main_process(): + """Return True if the current process is the main process. + + - For DistributedDataParallel (DDP) mode: + The current process is the main process if its rank is 0; the rank is read + from os.environ["RANK"] (expected to be already set in DDP mode). + - For standard mode and others: + os.environ["RANK"] is not set, so the process is treated as the main one. + """ + if os.environ.get('RANK') is None: + return True + elif os.environ.get("RANK") == "0": + return True + else: + return False