add tqdm to show a progress bar during training

songmeixu 2021-09-20 20:12:27 +08:00
parent a80e58e15d
commit 236b305698
2 changed files with 147 additions and 112 deletions

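For context, here is a minimal, self-contained sketch of the pattern this commit introduces: a tqdm bar whose progress is counted in lhotse "cuts" rather than batches, advanced manually by the number of cuts consumed per batch. The toy batches list and the placeholder loss are stand-ins for the real dataloader (train_dl) and training step; only the tqdm calls mirror the diff below.

# Minimal sketch (not the actual training script): a cut-based progress bar.
from tqdm import tqdm

batches = [["cut"] * 3, ["cut"] * 5, ["cut"] * 2]  # toy stand-in for train_dl
num_cuts = sum(len(b) for b in batches)  # stand-in for train_dl.sampler.num_cuts

with tqdm(
    total=num_cuts,
    unit="cut",
    leave=False,
    dynamic_ncols=True,
) as t:
    running_loss = 0.0
    for batch_idx, batch in enumerate(batches):
        running_loss += 0.1 * len(batch)  # placeholder for the real loss
        t.set_postfix(total_avg_loss=running_loss / (batch_idx + 1))
        t.update(len(batch))  # advance by the number of cuts in this batch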

@@ -28,6 +28,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from tqdm import tqdm
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed

@@ -46,6 +47,7 @@ from icefall.utils import (
     encode_supervisions,
     setup_logger,
     str2bool,
+    is_main_process,
 )

@@ -478,126 +480,142 @@ def train_one_epoch(

The training loop is now wrapped in a tqdm bar; the region reads as follows after the change:

    tot_frames = 0.0  # sum of frames over all batches
    params.tot_loss = 0.0
    params.tot_frames = 0.0

    with tqdm(
        train_dl,
        initial=0,
        total=train_dl.sampler.num_cuts,
        leave=False,
        unit="cut",  # a "cut" is lhotse's unit of audio data
        dynamic_ncols=True,
        disable=not is_main_process(),
    ) as t:
        for batch_idx, batch in enumerate(t):
            params.batch_idx_train += 1
            batch_size = len(batch["supervisions"]["text"])

            loss, ctc_loss, att_loss = compute_loss(
                params=params,
                model=model,
                batch=batch,
                graph_compiler=graph_compiler,
                is_training=True,
            )

            # NOTE: We use reduction==sum and loss is computed over utterances
            # in the batch and there is no normalization to it so far.

            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), 5.0, 2.0)
            optimizer.step()

            loss_cpu = loss.detach().cpu().item()
            ctc_loss_cpu = ctc_loss.detach().cpu().item()
            att_loss_cpu = att_loss.detach().cpu().item()

            tot_frames += params.train_frames
            tot_loss += loss_cpu
            tot_ctc_loss += ctc_loss_cpu
            tot_att_loss += att_loss_cpu
            params.tot_frames += params.train_frames
            params.tot_loss += loss_cpu

            tot_avg_loss = tot_loss / tot_frames
            tot_avg_ctc_loss = tot_ctc_loss / tot_frames
            tot_avg_att_loss = tot_att_loss / tot_frames

            t.set_postfix(total_avg_loss=tot_avg_loss)
            t.update(batch_size)

            if batch_idx % params.log_interval == 0:
                logging.info(
                    f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                    f"batch avg ctc loss "
                    f"{ctc_loss_cpu/params.train_frames:.4f}, "
                    f"batch avg att loss "
                    f"{att_loss_cpu/params.train_frames:.4f}, "
                    f"batch avg loss "
                    f"{loss_cpu/params.train_frames:.4f}, "
                    f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
                    f"total avg att loss: {tot_avg_att_loss:.4f}, "
                    f"total avg loss: {tot_avg_loss:.4f}, "
                    f"batch size: {batch_size}"
                )

                if tb_writer is not None:
                    tb_writer.add_scalar(
                        "train/current_ctc_loss",
                        ctc_loss_cpu / params.train_frames,
                        params.batch_idx_train,
                    )
                    tb_writer.add_scalar(
                        "train/current_att_loss",
                        att_loss_cpu / params.train_frames,
                        params.batch_idx_train,
                    )
                    tb_writer.add_scalar(
                        "train/current_loss",
                        loss_cpu / params.train_frames,
                        params.batch_idx_train,
                    )
                    tb_writer.add_scalar(
                        "train/tot_avg_ctc_loss",
                        tot_avg_ctc_loss,
                        params.batch_idx_train,
                    )
                    tb_writer.add_scalar(
                        "train/tot_avg_att_loss",
                        tot_avg_att_loss,
                        params.batch_idx_train,
                    )
                    tb_writer.add_scalar(
                        "train/tot_avg_loss",
                        tot_avg_loss,
                        params.batch_idx_train,
                    )

            if batch_idx > 0 and batch_idx % params.reset_interval == 0:
                tot_loss = 0.0  # sum of losses over all batches
                tot_ctc_loss = 0.0
                tot_att_loss = 0.0
                tot_frames = 0.0  # sum of frames over all batches

            if batch_idx > 0 and batch_idx % params.valid_interval == 0:
                compute_validation_loss(
                    params=params,
                    model=model,
                    graph_compiler=graph_compiler,
                    valid_dl=valid_dl,
                    world_size=world_size,
                )
                model.train()
                logging.info(
                    f"Epoch {params.cur_epoch}, "
                    f"valid ctc loss {params.valid_ctc_loss:.4f}, "
                    f"valid att loss {params.valid_att_loss:.4f}, "
                    f"valid loss {params.valid_loss:.4f}, "
                    f"best valid loss: {params.best_valid_loss:.4f}, "
                    f"best valid epoch: {params.best_valid_epoch}"
                )
                if tb_writer is not None:
                    tb_writer.add_scalar(
                        "train/valid_ctc_loss",
                        params.valid_ctc_loss,
                        params.batch_idx_train,
                    )
                    tb_writer.add_scalar(
                        "train/valid_att_loss",
                        params.valid_att_loss,
                        params.batch_idx_train,
                    )
                    tb_writer.add_scalar(
                        "train/valid_loss",
                        params.valid_loss,
                        params.batch_idx_train,
                    )

    params.train_loss = params.tot_loss / params.tot_frames

@@ -692,7 +710,7 @@ def run(rank, world_size, args):
             )
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
-        if rank == 0:
+        if is_main_process():
             logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
 
         params.cur_epoch = epoch


@@ -411,3 +411,20 @@ def write_error_stats(
         print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
 
     return float(tot_err_rate)
+
+
+def is_main_process():
+    """Check whether the current process is the main one.
+
+    - In DistributedDataParallel (DDP) mode:
+      the current process is the main process if its rank is 0; the rank is
+      read from os.environ["RANK"], which DDP mode has already set.
+    - In standard (single-process) mode and others:
+      os.environ["RANK"] is not set, so the process is treated as the main
+      one.
+    """
+    if os.environ.get("RANK") is None:
+        return True
+    elif os.environ.get("RANK") == "0":
+        return True
+    else:
+        return False
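
A small usage sketch of the helper above (assuming it is imported from icefall.utils, as the training script's import hunk does). The main_only_tqdm wrapper is a hypothetical name, not part of this commit; it shows the same gating idea as disable=not is_main_process() in the training loop.

import os

from tqdm import tqdm

def main_only_tqdm(iterable, **kwargs):
    # Only the main rank (rank 0, or a non-DDP run) draws a progress bar;
    # every other rank gets a silent pass-through iterator.
    return tqdm(iterable, disable=not is_main_process(), **kwargs)

# Quick check by faking the RANK variable that DDP would set.
os.environ["RANK"] = "1"
assert not is_main_process()
del os.environ["RANK"]
assert is_main_process()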