mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Minor fixes.
This commit is contained in:
parent
72c0220830
commit
14e0886559
@ -194,7 +194,10 @@ def load_checkpoint_if_available(
|
|||||||
|
|
||||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||||
saved_params = load_checkpoint(
|
saved_params = load_checkpoint(
|
||||||
filename, model=model, optimizer=optimizer, scheduler=scheduler,
|
filename,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
keys = [
|
keys = [
|
||||||
@ -512,6 +515,7 @@ def train_one_epoch(
|
|||||||
params.tot_frames = 0.0
|
params.tot_frames = 0.0
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
if batch_idx == 0:
|
if batch_idx == 0:
|
||||||
|
logging.info("save a batch for OOM handling")
|
||||||
# Use this batch to replace the batch that's causing OOM
|
# Use this batch to replace the batch that's causing OOM
|
||||||
params.saved_batch = batch
|
params.saved_batch = batch
|
||||||
|
|
||||||
@ -597,7 +601,9 @@ def train_one_epoch(
|
|||||||
params.batch_idx_train,
|
params.batch_idx_train,
|
||||||
)
|
)
|
||||||
tb_writer.add_scalar(
|
tb_writer.add_scalar(
|
||||||
"train/tot_avg_loss", tot_avg_loss, params.batch_idx_train,
|
"train/tot_avg_loss",
|
||||||
|
tot_avg_loss,
|
||||||
|
params.batch_idx_train,
|
||||||
)
|
)
|
||||||
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
||||||
tot_loss = 0.0 # sum of losses over all batches
|
tot_loss = 0.0 # sum of losses over all batches
|
||||||
@ -646,6 +652,9 @@ def train_one_epoch(
|
|||||||
params.best_train_epoch = params.cur_epoch
|
params.best_train_epoch = params.cur_epoch
|
||||||
params.best_train_loss = params.train_loss
|
params.best_train_loss = params.train_loss
|
||||||
|
|
||||||
|
if "saved_batch" in params:
|
||||||
|
del params["saved_batch"]
|
||||||
|
|
||||||
|
|
||||||
def run(rank, world_size, args):
|
def run(rank, world_size, args):
|
||||||
"""
|
"""
|
||||||
@ -749,10 +758,12 @@ def run(rank, world_size, args):
|
|||||||
tb_writer=tb_writer,
|
tb_writer=tb_writer,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
del params.saved_batch
|
|
||||||
|
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
params=params, model=model, optimizer=optimizer, rank=rank,
|
params=params,
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user