Fix OOM handling when using DDP.

We have to disable the batch norm layers; otherwise, the process hangs
indefinitely when we try to recover from a CUDA OOM error.
Author: Fangjun Kuang
Date:   2021-08-15 18:49:12 +08:00
parent 14e0886559
commit 21292066ec

3 changed files with 59 additions and 76 deletions
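The commit message is terse, so below is a minimal sketch (my own PyTorch illustration, not code from this repository) of why recovering from a CUDA OOM on a single rank can deadlock DDP when the model contains BatchNorm layers: the per-forward buffer broadcast (BatchNorm running statistics, with the default broadcast_buffers=True) and the per-backward gradient all-reduce are collective operations, so every rank must run the same number of forward/backward passes.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def train_one_batch(
    model: DDP, batch: torch.Tensor, optimizer: torch.optim.Optimizer
) -> None:
    try:
        loss = model(batch).sum()  # forward: DDP broadcasts buffers (BatchNorm stats)
        loss.backward()            # backward: DDP all-reduces gradients
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        # Freeing memory locally is the easy part ...
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()
        # ... but returning here means this rank runs fewer forward/backward
        # passes than its peers, so their pending collectives (the buffer
        # broadcast, the gradient all-reduce) never complete and every rank
        # hangs.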

@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 21.6b0
+    rev: 21.7b0
     hooks:
       - id: black
         args: [--line-length=80]

@@ -869,7 +869,10 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        # NOTE(fangjun): The process hangs when using DDP
+        # if we try to recover from CUDA OOM, so we disable
+        # batchnorm layer here.
+        # self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -899,7 +902,8 @@ class ConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x = self.activation(self.norm(x))
+        x = self.activation(x)
         x = self.pointwise_conv2(x)  # (batch, channel, time)
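To make the effect of this diff concrete, here is a hedged, self-contained sketch of the convolution module after the change. It assumes the standard Conformer convolution-module layout (a pointwise conv plus GLU ahead of the depthwise conv); the class name and constructor arguments are illustrative, not the repository's.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvolutionModuleSketch(nn.Module):
    """Conformer-style convolution module with the batch norm removed."""

    def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None:
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size must be odd"
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, 1, bias=bias)
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        # BatchNorm1d is intentionally omitted, matching the diff above: its
        # running-stat buffers are broadcast by DDP on every forward (with the
        # default broadcast_buffers=True), one of the collectives that falls
        # out of step once a rank tries to recover from a CUDA OOM.
        self.pointwise_conv2 = nn.Conv1d(channels, channels, 1, bias=bias)
        self.activation = nn.SiLU()  # "Swish" in the Conformer paper

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time)
        x = F.glu(self.pointwise_conv1(x), dim=1)  # back to `channels` channels
        x = self.depthwise_conv(x)
        x = self.activation(x)  # previously: self.activation(self.norm(x))
        return self.pointwise_conv2(x)  # (batch, channels, time)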

@@ -153,7 +153,7 @@ def get_params() -> AttributeDict:
             "num_decoder_layers": 6,
             "is_espnet_structure": True,
             "mmi_loss": False,
-            "use_feat_batchnorm": True,
+            "use_feat_batchnorm": False,
             "lr_factor": 2.0,
             "warm_step": 30000,
         }
@@ -282,13 +282,10 @@ def compute_loss_impl(
     assert feature.ndim == 3
     feature = feature.to(device)

-    try:
-        supervisions = batch["supervisions"]
-        with torch.set_grad_enabled(is_training):
-            nnet_output, encoder_memory, memory_mask = model(
-                feature, supervisions
-            )
+    supervisions = batch["supervisions"]
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)

     # nnet_output is [N, T, C]
     # NOTE: We need `encode_supervisions` to sort sequences with
@@ -334,23 +331,10 @@ def compute_loss_impl(
             sos_id=graph_compiler.sos_id,
             eos_id=graph_compiler.eos_id,
         )
-        loss = (
-            1.0 - params.att_rate
-        ) * ctc_loss + params.att_rate * att_loss
+        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
     else:
         loss = ctc_loss
         att_loss = torch.tensor([0])
-    except RuntimeError as ex:
-        try:
-            del nnet_output
-            del encoder_memory
-            del dense_fsa_vec
-            del ctc_loss
-            del att_loss
-            del loss
-        except NameError as ne:
-            pass
-        raise ex

     # train_frames and valid_frames are used for printing.
     if is_training:
@@ -394,11 +378,6 @@ def compute_loss(
         s += f" max duration: {max_cut_duration:.3f} s \n"
         logging.info(s)

-        # see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
-        for p in model.parameters():
-            if p.grad is not None:
-                del p.grad  # free some memory
         torch.cuda.empty_cache()
         gc.collect()
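The surviving context in compute_loss suggests that OOM handling now lives in the wrapper rather than inside compute_loss_impl: the wrapper logs the offending batch and frees cached memory. The sketch below is my reconstruction of that pattern, not the repository's exact code; the signature, the lhotse-style batch layout, and the decision to re-raise are assumptions.

import gc
import logging

import torch


def compute_loss(params, model, batch, graph_compiler, is_training: bool):
    """Thin wrapper around compute_loss_impl (the function patched above)."""
    try:
        return compute_loss_impl(params, model, batch, graph_compiler, is_training)
    except RuntimeError as e:
        if "CUDA out of memory" not in str(e):
            raise
        # Assumes a lhotse-style batch where each supervision carries its cut,
        # so we can report which utterance lengths triggered the OOM.
        durations = [c.duration for c in batch["supervisions"]["cut"]]
        logging.info(
            f"OOM with {len(durations)} cuts, max duration: {max(durations):.3f} s"
        )
        # Release whatever the failed forward/backward left on the GPU before
        # letting the caller decide whether to skip the batch or abort.
        torch.cuda.empty_cache()
        gc.collect()
        raise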