mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-15 20:22:42 +00:00
add tqdm for getting a progress bar of the training
This commit is contained in:
parent
a80e58e15d
commit
236b305698
@ -28,6 +28,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from tqdm import tqdm
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
@ -46,6 +47,7 @@ from icefall.utils import (
|
|||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
|
is_main_process,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -478,126 +480,142 @@ def train_one_epoch(
|
|||||||
tot_frames = 0.0 # sum of frames over all batches
|
tot_frames = 0.0 # sum of frames over all batches
|
||||||
params.tot_loss = 0.0
|
params.tot_loss = 0.0
|
||||||
params.tot_frames = 0.0
|
params.tot_frames = 0.0
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
|
||||||
params.batch_idx_train += 1
|
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
|
||||||
|
|
||||||
loss, ctc_loss, att_loss = compute_loss(
|
with tqdm(
|
||||||
params=params,
|
train_dl,
|
||||||
model=model,
|
initial=0,
|
||||||
batch=batch,
|
total=train_dl.sampler.num_cuts,
|
||||||
graph_compiler=graph_compiler,
|
leave=False,
|
||||||
is_training=True,
|
unit="cut", # a data fragment concept comes from "lhotse"
|
||||||
)
|
dynamic_ncols=True,
|
||||||
|
disable=not is_main_process(),
|
||||||
|
) as t:
|
||||||
|
for batch_idx, batch in enumerate(t):
|
||||||
|
params.batch_idx_train += 1
|
||||||
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
loss, ctc_loss, att_loss = compute_loss(
|
||||||
# in the batch and there is no normalization to it so far.
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
loss_cpu = loss.detach().cpu().item()
|
|
||||||
ctc_loss_cpu = ctc_loss.detach().cpu().item()
|
|
||||||
att_loss_cpu = att_loss.detach().cpu().item()
|
|
||||||
|
|
||||||
tot_frames += params.train_frames
|
|
||||||
tot_loss += loss_cpu
|
|
||||||
tot_ctc_loss += ctc_loss_cpu
|
|
||||||
tot_att_loss += att_loss_cpu
|
|
||||||
|
|
||||||
params.tot_frames += params.train_frames
|
|
||||||
params.tot_loss += loss_cpu
|
|
||||||
|
|
||||||
tot_avg_loss = tot_loss / tot_frames
|
|
||||||
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
|
|
||||||
tot_avg_att_loss = tot_att_loss / tot_frames
|
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
|
||||||
logging.info(
|
|
||||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
|
||||||
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
|
|
||||||
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
|
|
||||||
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
|
|
||||||
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
|
|
||||||
f"total avg att loss: {tot_avg_att_loss:.4f}, "
|
|
||||||
f"total avg loss: {tot_avg_loss:.4f}, "
|
|
||||||
f"batch size: {batch_size}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if tb_writer is not None:
|
|
||||||
tb_writer.add_scalar(
|
|
||||||
"train/current_ctc_loss",
|
|
||||||
ctc_loss_cpu / params.train_frames,
|
|
||||||
params.batch_idx_train,
|
|
||||||
)
|
|
||||||
tb_writer.add_scalar(
|
|
||||||
"train/current_att_loss",
|
|
||||||
att_loss_cpu / params.train_frames,
|
|
||||||
params.batch_idx_train,
|
|
||||||
)
|
|
||||||
tb_writer.add_scalar(
|
|
||||||
"train/current_loss",
|
|
||||||
loss_cpu / params.train_frames,
|
|
||||||
params.batch_idx_train,
|
|
||||||
)
|
|
||||||
tb_writer.add_scalar(
|
|
||||||
"train/tot_avg_ctc_loss",
|
|
||||||
tot_avg_ctc_loss,
|
|
||||||
params.batch_idx_train,
|
|
||||||
)
|
|
||||||
|
|
||||||
tb_writer.add_scalar(
|
|
||||||
"train/tot_avg_att_loss",
|
|
||||||
tot_avg_att_loss,
|
|
||||||
params.batch_idx_train,
|
|
||||||
)
|
|
||||||
tb_writer.add_scalar(
|
|
||||||
"train/tot_avg_loss",
|
|
||||||
tot_avg_loss,
|
|
||||||
params.batch_idx_train,
|
|
||||||
)
|
|
||||||
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
|
||||||
tot_loss = 0.0 # sum of losses over all batches
|
|
||||||
tot_ctc_loss = 0.0
|
|
||||||
tot_att_loss = 0.0
|
|
||||||
|
|
||||||
tot_frames = 0.0 # sum of frames over all batches
|
|
||||||
|
|
||||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
|
||||||
compute_validation_loss(
|
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
batch=batch,
|
||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
valid_dl=valid_dl,
|
is_training=True,
|
||||||
world_size=world_size,
|
|
||||||
)
|
)
|
||||||
model.train()
|
|
||||||
logging.info(
|
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||||
f"Epoch {params.cur_epoch}, "
|
# in the batch and there is no normalization to it so far.
|
||||||
f"valid ctc loss {params.valid_ctc_loss:.4f},"
|
|
||||||
f"valid att loss {params.valid_att_loss:.4f},"
|
optimizer.zero_grad()
|
||||||
f"valid loss {params.valid_loss:.4f},"
|
loss.backward()
|
||||||
f" best valid loss: {params.best_valid_loss:.4f} "
|
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||||
f"best valid epoch: {params.best_valid_epoch}"
|
optimizer.step()
|
||||||
)
|
|
||||||
if tb_writer is not None:
|
loss_cpu = loss.detach().cpu().item()
|
||||||
tb_writer.add_scalar(
|
ctc_loss_cpu = ctc_loss.detach().cpu().item()
|
||||||
"train/valid_ctc_loss",
|
att_loss_cpu = att_loss.detach().cpu().item()
|
||||||
params.valid_ctc_loss,
|
|
||||||
params.batch_idx_train,
|
tot_frames += params.train_frames
|
||||||
|
tot_loss += loss_cpu
|
||||||
|
tot_ctc_loss += ctc_loss_cpu
|
||||||
|
tot_att_loss += att_loss_cpu
|
||||||
|
|
||||||
|
params.tot_frames += params.train_frames
|
||||||
|
params.tot_loss += loss_cpu
|
||||||
|
|
||||||
|
tot_avg_loss = tot_loss / tot_frames
|
||||||
|
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
|
||||||
|
tot_avg_att_loss = tot_att_loss / tot_frames
|
||||||
|
|
||||||
|
t.set_postfix(total_avg_loss=tot_avg_loss)
|
||||||
|
t.update(batch_size)
|
||||||
|
|
||||||
|
if batch_idx % params.log_interval == 0:
|
||||||
|
logging.info(
|
||||||
|
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||||
|
f"batch avg ctc loss \
|
||||||
|
{ctc_loss_cpu/params.train_frames:.4f}, "
|
||||||
|
f"batch avg att loss \
|
||||||
|
{att_loss_cpu/params.train_frames:.4f}, "
|
||||||
|
f"batch avg loss \
|
||||||
|
{loss_cpu/params.train_frames:.4f}, "
|
||||||
|
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
|
||||||
|
f"total avg att loss: {tot_avg_att_loss:.4f}, "
|
||||||
|
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||||
|
f"batch size: {batch_size}"
|
||||||
)
|
)
|
||||||
tb_writer.add_scalar(
|
|
||||||
"train/valid_att_loss",
|
if tb_writer is not None:
|
||||||
params.valid_att_loss,
|
tb_writer.add_scalar(
|
||||||
params.batch_idx_train,
|
"train/current_ctc_loss",
|
||||||
|
ctc_loss_cpu / params.train_frames,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/current_att_loss",
|
||||||
|
att_loss_cpu / params.train_frames,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/current_loss",
|
||||||
|
loss_cpu / params.train_frames,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/tot_avg_ctc_loss",
|
||||||
|
tot_avg_ctc_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/tot_avg_att_loss",
|
||||||
|
tot_avg_att_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/tot_avg_loss",
|
||||||
|
tot_avg_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
||||||
|
tot_loss = 0.0 # sum of losses over all batches
|
||||||
|
tot_ctc_loss = 0.0
|
||||||
|
tot_att_loss = 0.0
|
||||||
|
|
||||||
|
tot_frames = 0.0 # sum of frames over all batches
|
||||||
|
|
||||||
|
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||||
|
compute_validation_loss(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
graph_compiler=graph_compiler,
|
||||||
|
valid_dl=valid_dl,
|
||||||
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
tb_writer.add_scalar(
|
model.train()
|
||||||
"train/valid_loss",
|
logging.info(
|
||||||
params.valid_loss,
|
f"Epoch {params.cur_epoch}, "
|
||||||
params.batch_idx_train,
|
f"valid ctc loss {params.valid_ctc_loss:.4f},"
|
||||||
|
f"valid att loss {params.valid_att_loss:.4f},"
|
||||||
|
f"valid loss {params.valid_loss:.4f},"
|
||||||
|
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||||
|
f"best valid epoch: {params.best_valid_epoch}"
|
||||||
)
|
)
|
||||||
|
if tb_writer is not None:
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/valid_ctc_loss",
|
||||||
|
params.valid_ctc_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/valid_att_loss",
|
||||||
|
params.valid_att_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
tb_writer.add_scalar(
|
||||||
|
"train/valid_loss",
|
||||||
|
params.valid_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
|
)
|
||||||
|
|
||||||
params.train_loss = params.tot_loss / params.tot_frames
|
params.train_loss = params.tot_loss / params.tot_frames
|
||||||
|
|
||||||
@ -692,7 +710,7 @@ def run(rank, world_size, args):
|
|||||||
)
|
)
|
||||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||||
|
|
||||||
if rank == 0:
|
if is_main_process():
|
||||||
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||||
|
|
||||||
params.cur_epoch = epoch
|
params.cur_epoch = epoch
|
||||||
|
@ -411,3 +411,20 @@ def write_error_stats(
|
|||||||
|
|
||||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||||
return float(tot_err_rate)
|
return float(tot_err_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
"""Check if the current process is main.
|
||||||
|
|
||||||
|
- For DistributedDataParallel (DDP) mode:
|
||||||
|
The current is main process if its rank is 0, and its rank get from
|
||||||
|
os.environ["RANK"] (which already be set by DDP mode).
|
||||||
|
- For standard mode and others:
|
||||||
|
The os.environ["RANK"] is None, and
|
||||||
|
"""
|
||||||
|
if os.environ.get('RANK') is None:
|
||||||
|
return True
|
||||||
|
elif os.environ.get("RANK") == "0":
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user