mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
Fix OOM handling when using DDP.

We have to disable the batch norm layers; otherwise, the process will hang indefinitely.
This commit is contained in:
  parent 14e0886559
  commit 21292066ec
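For context on the hang itself: `torch.nn.parallel.DistributedDataParallel` broadcasts module buffers, including BatchNorm's `running_mean` and `running_var`, from rank 0 at the start of every forward pass (with the default `broadcast_buffers=True`). If one rank hits CUDA OOM and abandons a step that its peers have already begun, the peers block in that collective forever. The sketch below illustrates the failure mode; the training-step shape is an assumption for illustration, not icefall code.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def train_step(model: DDP, batch: torch.Tensor) -> None:
    # With the default broadcast_buffers=True and a module that owns
    # buffers (e.g. nn.BatchNorm1d), every model(...) call first joins
    # a cross-rank buffer broadcast.
    try:
        loss = model(batch).sum()
        loss.backward()
    except RuntimeError as ex:
        if "out of memory" not in str(ex):
            raise
        # Recovering here means this rank skips work that its peers are
        # still blocked on inside the collective, so the job hangs.
        # Removing the BatchNorm layers (as this commit does) removes
        # the buffers and hence the per-forward broadcast.
        torch.cuda.empty_cache()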
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 21.6b0
+    rev: 21.7b0
     hooks:
       - id: black
         args: [--line-length=80]
@@ -869,7 +869,10 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        # NOTE(fangjun): The process hangs when using DDP
+        # if we try to recover from CUDA OOM, so we disable
+        # batchnorm layer here.
+        # self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
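An aside, not taken by this commit: if the hang comes from DDP's per-forward buffer broadcast, an alternative to deleting the layer would be to construct the DDP wrapper with `broadcast_buffers=False`, at the cost of letting each rank's running statistics drift apart. A minimal sketch under that assumption:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(model: torch.nn.Module, rank: int) -> DDP:
    # broadcast_buffers=False stops DDP from broadcasting buffer state
    # (e.g. BatchNorm running stats) at every forward, so a rank that
    # skips a batch after OOM no longer strands its peers in that
    # collective. Trade-off: per-rank running stats are no longer synced.
    return DDP(model, device_ids=[rank], broadcast_buffers=False)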
@@ -899,7 +902,8 @@ class ConvolutionModule(nn.Module):
 
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x = self.activation(self.norm(x))
+        x = self.activation(x)
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
@@ -153,7 +153,7 @@ def get_params() -> AttributeDict:
             "num_decoder_layers": 6,
             "is_espnet_structure": True,
             "mmi_loss": False,
-            "use_feat_batchnorm": True,
+            "use_feat_batchnorm": False,
             "lr_factor": 2.0,
             "warm_step": 30000,
         }
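A note on the `use_feat_batchnorm` default flipping to False: judging by the flag's name (the code it gates is not shown in this diff), it enables a batch-norm layer applied to the input features, which would carry the same DDP buffer hazard that the ConvolutionModule change above removes.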
@@ -282,75 +282,59 @@ def compute_loss_impl(
     assert feature.ndim == 3
     feature = feature.to(device)
 
-    try:
-        supervisions = batch["supervisions"]
+    supervisions = batch["supervisions"]
 
-        with torch.set_grad_enabled(is_training):
-            nnet_output, encoder_memory, memory_mask = model(
-                feature, supervisions
-            )
-            # nnet_output is [N, T, C]
-
-        # NOTE: We need `encode_supervisions` to sort sequences with
-        # different duration in decreasing order, required by
-        # `k2.intersect_dense` called in `k2.ctc_loss`
-        supervision_segments, texts = encode_supervisions(
-            supervisions, subsampling_factor=params.subsampling_factor
-        )
-
-        token_ids = graph_compiler.texts_to_ids(texts)
-
-        decoding_graph = graph_compiler.compile(token_ids)
-
-        dense_fsa_vec = k2.DenseFsaVec(
-            nnet_output,
-            supervision_segments,
-            allow_truncate=params.subsampling_factor - 1,
-        )
-
-        ctc_loss = k2.ctc_loss(
-            decoding_graph=decoding_graph,
-            dense_fsa_vec=dense_fsa_vec,
-            output_beam=params.beam_size,
-            reduction=params.reduction,
-            use_double_scores=params.use_double_scores,
-        )
-
-        if params.att_rate != 0.0:
-            with torch.set_grad_enabled(is_training):
-                if hasattr(model, "module"):
-                    att_loss = model.module.decoder_forward(
-                        encoder_memory,
-                        memory_mask,
-                        token_ids=token_ids,
-                        sos_id=graph_compiler.sos_id,
-                        eos_id=graph_compiler.eos_id,
-                    )
-                else:
-                    att_loss = model.decoder_forward(
-                        encoder_memory,
-                        memory_mask,
-                        token_ids=token_ids,
-                        sos_id=graph_compiler.sos_id,
-                        eos_id=graph_compiler.eos_id,
-                    )
-            loss = (
-                1.0 - params.att_rate
-            ) * ctc_loss + params.att_rate * att_loss
-        else:
-            loss = ctc_loss
-            att_loss = torch.tensor([0])
-    except RuntimeError as ex:
-        try:
-            del nnet_output
-            del encoder_memory
-            del dense_fsa_vec
-            del ctc_loss
-            del att_loss
-            del loss
-        except NameError as ne:
-            pass
-        raise ex
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is [N, T, C]
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    supervision_segments, texts = encode_supervisions(
+        supervisions, subsampling_factor=params.subsampling_factor
+    )
+
+    token_ids = graph_compiler.texts_to_ids(texts)
+
+    decoding_graph = graph_compiler.compile(token_ids)
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        reduction=params.reduction,
+        use_double_scores=params.use_double_scores,
+    )
+
+    if params.att_rate != 0.0:
+        with torch.set_grad_enabled(is_training):
+            if hasattr(model, "module"):
+                att_loss = model.module.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+            else:
+                att_loss = model.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+    else:
+        loss = ctc_loss
+        att_loss = torch.tensor([0])
 
     # train_frames and valid_frames are used for printing.
     if is_training:
@@ -394,11 +378,6 @@ def compute_loss(
         s += f" max duration: {max_cut_duration:.3f} s \n"
         logging.info(s)
 
-        # see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
-        for p in model.parameters():
-            if p.grad is not None:
-                del p.grad  # free some memory
-
         torch.cuda.empty_cache()
 
         gc.collect()