Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00.
Fix oom handling.

commit 72c0220830
parent 36ac512d00
@@ -3,6 +3,7 @@
 # This is just at the very beginning ...

 import argparse
+import gc
 import logging
 from pathlib import Path
 from shutil import copyfile
@@ -193,10 +194,7 @@ def load_checkpoint_if_available(

     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )

     keys = [
@@ -281,9 +279,13 @@ def compute_loss_impl(
     assert feature.ndim == 3
     feature = feature.to(device)

+    try:
         supervisions = batch["supervisions"]
         with torch.set_grad_enabled(is_training):
-            nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+            nnet_output, encoder_memory, memory_mask = model(
+                feature, supervisions
+            )
             # nnet_output is [N, T, C]

             # NOTE: We need `encode_supervisions` to sort sequences with
@@ -329,10 +331,23 @@ def compute_loss_impl(
                 sos_id=graph_compiler.sos_id,
                 eos_id=graph_compiler.eos_id,
             )
-            loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+            loss = (
+                1.0 - params.att_rate
+            ) * ctc_loss + params.att_rate * att_loss
         else:
             loss = ctc_loss
             att_loss = torch.tensor([0])
+    except RuntimeError as ex:
+        try:
+            del nnet_output
+            del encoder_memory
+            del dense_fsa_vec
+            del ctc_loss
+            del att_loss
+            del loss
+        except NameError as ne:
+            pass
+        raise ex

     # train_frames and valid_frames are used for printing.
     if is_training:
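Taken together, the two hunks in compute_loss_impl wrap the forward pass and loss computation in a try/except: on a RuntimeError (typically a CUDA out-of-memory error), every intermediate tensor that was created before the failure is deleted before the exception is re-raised, so the caller can actually reclaim that memory. The del statements sit inside their own try because any name bound after the failure point does not exist yet; the first missing name raises UnboundLocalError (a subclass of NameError), which is deliberately swallowed. Below is a minimal, self-contained sketch of that pattern, not the commit's code: the tensor names and the fail flag are made up for illustration.

# Minimal sketch of the cleanup pattern introduced by this commit.
# The tensor names and the `fail` flag are hypothetical.
import torch


def compute(fail: bool = False) -> torch.Tensor:
    try:
        a = torch.randn(4, 4)
        if fail:
            # Stand-in for a CUDA out-of-memory failure mid-computation.
            raise RuntimeError("CUDA out of memory")
        b = a @ a.t()
        loss = b.sum()
        return loss
    except RuntimeError as ex:
        # Free whatever was created before the failure.  Deleting a name
        # that was never bound raises UnboundLocalError, which aborts the
        # remaining dels; that matches the commit's sequential del block.
        try:
            del a
            del b
            del loss
        except NameError:
            pass
        raise ex


print(compute())  # the normal path returns the loss tensor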
@@ -364,6 +379,7 @@ def compute_loss(
         if "out of memory" not in str(ex):
             raise ex

+        logging.exception(ex)
         s = f"\nCaught exception: {str(ex)}\n"
         total_duration = 0.0
         max_cut_duration = 0.0
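The one-line addition in compute_loss records the OOM before the recovery path runs. logging.exception is the stdlib helper that logs at ERROR level and automatically appends the traceback of the exception currently being handled, so it is only meaningful inside an except block. A tiny stdlib-only illustration follows; the risky function is a made-up stand-in for the failing training step.

import logging

logging.basicConfig(level=logging.INFO)


def risky() -> None:
    # Hypothetical stand-in for the training step that runs out of memory.
    raise RuntimeError("CUDA out of memory")


try:
    risky()
except RuntimeError as ex:
    if "out of memory" not in str(ex):
        raise ex
    # Logs the message at ERROR level and appends the active traceback.
    logging.exception(ex)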
@@ -379,8 +395,12 @@ def compute_loss(
         for p in model.parameters():
             if p.grad is not None:
                 del p.grad  # free some memory
+
         torch.cuda.empty_cache()
+
+        gc.collect()
+
         # See https://github.com/pytorch/pytorch/issues/18853#issuecomment-583779161
         return compute_loss_impl(
             params=params,
             model=model,
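This hunk completes the recovery path: after a detected OOM, compute_loss drops every parameter's .grad so its storage can be reclaimed, flushes the CUDA caching allocator, and, new in this commit, runs gc.collect() to break Python-level reference cycles that can keep dead tensors alive (the motivation in the linked PyTorch issue comment), then retries the batch via compute_loss_impl. A condensed sketch of that recover-and-retry flow, under one assumption: step is a hypothetical callable standing in for compute_loss_impl.

import gc
import logging

import torch


def run_with_oom_retry(model: torch.nn.Module, step, *args):
    # `step` is a hypothetical stand-in for compute_loss_impl.
    try:
        return step(*args)
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise ex
        logging.exception(ex)
        # Drop gradients so their storage can be reclaimed.
        for p in model.parameters():
            if p.grad is not None:
                del p.grad
        # Return cached, unused blocks to the CUDA driver.
        torch.cuda.empty_cache()
        # Break reference cycles that may still pin dead tensors; see
        # https://github.com/pytorch/pytorch/issues/18853#issuecomment-583779161
        gc.collect()
        # Retry once; a second OOM propagates to the caller.
        return step(*args)

The single retry mirrors the commit: if the batch still does not fit after the cleanup, the second RuntimeError propagates unchanged.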
@@ -577,9 +597,7 @@ def train_one_epoch(
                 params.batch_idx_train,
             )
             tb_writer.add_scalar(
-                "train/tot_avg_loss",
-                tot_avg_loss,
-                params.batch_idx_train,
+                "train/tot_avg_loss", tot_avg_loss, params.batch_idx_train,
             )
         if batch_idx > 0 and batch_idx % params.reset_interval == 0:
             tot_loss = 0.0  # sum of losses over all batches
@@ -734,10 +752,7 @@ def run(rank, world_size, args):
     del params.saved_batch

     save_checkpoint(
-        params=params,
-        model=model,
-        optimizer=optimizer,
-        rank=rank,
+        params=params, model=model, optimizer=optimizer, rank=rank,
     )

     logging.info("Done!")