Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Fix oom handling.

commit 72c0220830 (parent 36ac512d00)
@@ -3,6 +3,7 @@
 # This is just at the very beginning ...
 
 import argparse
+import gc
 import logging
 from pathlib import Path
 from shutil import copyfile
@@ -193,10 +194,7 @@ def load_checkpoint_if_available(
 
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename,
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
+        filename, model=model, optimizer=optimizer, scheduler=scheduler,
     )
 
     keys = [
@@ -281,58 +279,75 @@ def compute_loss_impl(
     assert feature.ndim == 3
     feature = feature.to(device)
 
-    supervisions = batch["supervisions"]
-    with torch.set_grad_enabled(is_training):
-        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
+    try:
+        supervisions = batch["supervisions"]
+        with torch.set_grad_enabled(is_training):
+            nnet_output, encoder_memory, memory_mask = model(
+                feature, supervisions
+            )
+            # nnet_output is [N, T, C]
 
-    # NOTE: We need `encode_supervisions` to sort sequences with
-    # different duration in decreasing order, required by
-    # `k2.intersect_dense` called in `k2.ctc_loss`
-    supervision_segments, texts = encode_supervisions(
-        supervisions, subsampling_factor=params.subsampling_factor
-    )
+        # NOTE: We need `encode_supervisions` to sort sequences with
+        # different duration in decreasing order, required by
+        # `k2.intersect_dense` called in `k2.ctc_loss`
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
 
-    token_ids = graph_compiler.texts_to_ids(texts)
+        token_ids = graph_compiler.texts_to_ids(texts)
 
-    decoding_graph = graph_compiler.compile(token_ids)
+        decoding_graph = graph_compiler.compile(token_ids)
 
-    dense_fsa_vec = k2.DenseFsaVec(
-        nnet_output,
-        supervision_segments,
-        allow_truncate=params.subsampling_factor - 1,
-    )
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
 
-    ctc_loss = k2.ctc_loss(
-        decoding_graph=decoding_graph,
-        dense_fsa_vec=dense_fsa_vec,
-        output_beam=params.beam_size,
-        reduction=params.reduction,
-        use_double_scores=params.use_double_scores,
-    )
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
 
-    if params.att_rate != 0.0:
-        with torch.set_grad_enabled(is_training):
-            if hasattr(model, "module"):
-                att_loss = model.module.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-            else:
-                att_loss = model.decoder_forward(
-                    encoder_memory,
-                    memory_mask,
-                    token_ids=token_ids,
-                    sos_id=graph_compiler.sos_id,
-                    eos_id=graph_compiler.eos_id,
-                )
-        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
-    else:
-        loss = ctc_loss
-        att_loss = torch.tensor([0])
+        if params.att_rate != 0.0:
+            with torch.set_grad_enabled(is_training):
+                if hasattr(model, "module"):
+                    att_loss = model.module.decoder_forward(
+                        encoder_memory,
+                        memory_mask,
+                        token_ids=token_ids,
+                        sos_id=graph_compiler.sos_id,
+                        eos_id=graph_compiler.eos_id,
+                    )
+                else:
+                    att_loss = model.decoder_forward(
+                        encoder_memory,
+                        memory_mask,
+                        token_ids=token_ids,
+                        sos_id=graph_compiler.sos_id,
+                        eos_id=graph_compiler.eos_id,
+                    )
+            loss = (
+                1.0 - params.att_rate
+            ) * ctc_loss + params.att_rate * att_loss
+        else:
+            loss = ctc_loss
+            att_loss = torch.tensor([0])
+    except RuntimeError as ex:
+        try:
+            del nnet_output
+            del encoder_memory
+            del dense_fsa_vec
+            del ctc_loss
+            del att_loss
+            del loss
+        except NameError as ne:
+            pass
+        raise ex
 
     # train_frames and valid_frames are used for printing.
     if is_training:
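The try/except added in the hunk above is a cleanup-on-failure pattern: if the forward pass or the loss computation raises (typically a CUDA out-of-memory error), the references to any intermediate tensors that were already created are dropped before re-raising, so the caller can actually reclaim that GPU memory. A minimal, self-contained sketch of the pattern follows; run_step and the tensors inside it are hypothetical names used only for illustration, not part of this diff.

import torch


def run_step(n: int) -> torch.Tensor:
    """Sketch of a step that may fail partway through, e.g. with CUDA OOM."""
    try:
        activations = torch.randn(n, n)  # created successfully
        if n > 4096:
            raise RuntimeError("CUDA out of memory (simulated)")
        loss = activations.sum()  # may never be reached
        return loss
    except RuntimeError as ex:
        # Delete whatever was created before the failure. Guard with NameError
        # because later variables may not exist yet when the exception fires
        # (del on an unbound local raises UnboundLocalError, a subclass of
        # NameError).
        try:
            del activations
            del loss
        except NameError:
            pass
        raise ex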
@@ -364,6 +379,7 @@ def compute_loss(
         if "out of memory" not in str(ex):
             raise ex
 
+        logging.exception(ex)
         s = f"\nCaught exception: {str(ex)}\n"
         total_duration = 0.0
         max_cut_duration = 0.0
@@ -375,19 +391,23 @@ def compute_loss(
         s += f" max duration: {max_cut_duration:.3f} s \n"
         logging.info(s)
 
         # see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
         for p in model.parameters():
             if p.grad is not None:
                 del p.grad  # free some memory
-        torch.cuda.empty_cache()
 
+        torch.cuda.empty_cache()
+
+        gc.collect()
+
+        # See https://github.com/pytorch/pytorch/issues/18853#issuecomment-583779161
         return compute_loss_impl(
             params=params,
             model=model,
             batch=params.saved_batch,
             graph_compiler=graph_compiler,
             is_training=is_training,
         )
 
 
 def compute_validation_loss(
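The hunk above changes the out-of-memory fallback in compute_loss: after freeing parameter gradients it now also empties the CUDA caching allocator and forces a garbage-collection pass (hence the new import gc) before retrying the computation on the saved batch. A hedged sketch of that recovery-and-retry flow is below; compute_with_oom_retry, compute_fn, and fallback_batch are illustrative stand-ins, not the repository's API.

import gc
import logging

import torch


def compute_with_oom_retry(compute_fn, model, batch, fallback_batch):
    """Run compute_fn once; on a CUDA OOM, free memory and retry it
    on a previously saved batch."""
    try:
        return compute_fn(model, batch)
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise
        logging.exception(ex)
        # Drop gradients so their storage can be released.
        for p in model.parameters():
            if p.grad is not None:
                del p.grad
        # Return cached blocks to the driver and collect dangling references.
        torch.cuda.empty_cache()
        gc.collect()
        return compute_fn(model, fallback_batch)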
@@ -577,9 +597,7 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )
                 tb_writer.add_scalar(
-                    "train/tot_avg_loss",
-                    tot_avg_loss,
-                    params.batch_idx_train,
+                    "train/tot_avg_loss", tot_avg_loss, params.batch_idx_train,
                 )
             if batch_idx > 0 and batch_idx % params.reset_interval == 0:
                 tot_loss = 0.0  # sum of losses over all batches
@@ -734,10 +752,7 @@ def run(rank, world_size, args):
         del params.saved_batch
 
     save_checkpoint(
-        params=params,
-        model=model,
-        optimizer=optimizer,
-        rank=rank,
+        params=params, model=model, optimizer=optimizer, rank=rank,
     )
 
     logging.info("Done!")