mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-16 12:42:20 +00:00
add tqdm for getting a progress bar of the training
This commit is contained in:
parent
a80e58e15d
commit
236b305698
@ -28,6 +28,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from tqdm import tqdm
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
@ -46,6 +47,7 @@ from icefall.utils import (
|
|||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
|
is_main_process,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -478,7 +480,17 @@ def train_one_epoch(
|
|||||||
tot_frames = 0.0 # sum of frames over all batches
|
tot_frames = 0.0 # sum of frames over all batches
|
||||||
params.tot_loss = 0.0
|
params.tot_loss = 0.0
|
||||||
params.tot_frames = 0.0
|
params.tot_frames = 0.0
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
|
||||||
|
with tqdm(
|
||||||
|
train_dl,
|
||||||
|
initial=0,
|
||||||
|
total=train_dl.sampler.num_cuts,
|
||||||
|
leave=False,
|
||||||
|
unit="cut", # a data fragment concept comes from "lhotse"
|
||||||
|
dynamic_ncols=True,
|
||||||
|
disable=not is_main_process(),
|
||||||
|
) as t:
|
||||||
|
for batch_idx, batch in enumerate(t):
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
@ -514,12 +526,18 @@ def train_one_epoch(
|
|||||||
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
|
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
|
||||||
tot_avg_att_loss = tot_att_loss / tot_frames
|
tot_avg_att_loss = tot_att_loss / tot_frames
|
||||||
|
|
||||||
|
t.set_postfix(total_avg_loss=tot_avg_loss)
|
||||||
|
t.update(batch_size)
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if batch_idx % params.log_interval == 0:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||||
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
|
f"batch avg ctc loss \
|
||||||
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
|
{ctc_loss_cpu/params.train_frames:.4f}, "
|
||||||
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
|
f"batch avg att loss \
|
||||||
|
{att_loss_cpu/params.train_frames:.4f}, "
|
||||||
|
f"batch avg loss \
|
||||||
|
{loss_cpu/params.train_frames:.4f}, "
|
||||||
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
|
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
|
||||||
f"total avg att loss: {tot_avg_att_loss:.4f}, "
|
f"total avg att loss: {tot_avg_att_loss:.4f}, "
|
||||||
f"total avg loss: {tot_avg_loss:.4f}, "
|
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||||
@ -692,7 +710,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
if rank == 0:
|
if is_main_process():
|
||||||
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||||
|
|
||||||
params.cur_epoch = epoch
|
params.cur_epoch = epoch
|
||||||
|
@ -411,3 +411,20 @@ def write_error_stats(
|
|||||||
|
|
||||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||||
return float(tot_err_rate)
|
return float(tot_err_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
"""Check if the current process is main.
|
||||||
|
|
||||||
|
- For DistributedDataParallel (DDP) mode:
|
||||||
|
The current is main process if its rank is 0, and its rank get from
|
||||||
|
os.environ["RANK"] (which already be set by DDP mode).
|
||||||
|
- For standard mode and others:
|
||||||
|
The os.environ["RANK"] is None, and
|
||||||
|
"""
|
||||||
|
if os.environ.get('RANK') is None:
|
||||||
|
return True
|
||||||
|
elif os.environ.get("RANK") == "0":
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user