Zero out the gradient of decoder/joiner for auxiliary losses.

Fangjun Kuang 2022-02-22 19:44:01 +08:00
parent 76632bddfe
commit 4319a187b3


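This change replaces the single combined call `(transducer_loss + aux_loss * params.lambda_aux).backward()` with two backward passes, so that the gradients the auxiliary loss produces in the decoder and joiner can be discarded and only the encoder is trained by the auxiliary data. Stripped of the DDP bookkeeping shown in the diff below, a minimal single-process sketch of the new step (the names `model`, `transducer_loss`, `aux_loss`, `lambda_aux` and the clipping values follow the diff; wrapping the step in a `train_step` function is purely illustrative):

# Minimal single-process sketch of the training step after this commit; the
# DDP-specific calls from the diff are omitted.
from torch.nn.utils import clip_grad_norm_


def train_step(model, optimizer, transducer_loss, aux_loss, lambda_aux):
    optimizer.zero_grad()

    # 1. Backprop the scaled auxiliary loss, keeping the graph alive because
    #    the transducer loss reuses the same forward pass.
    (aux_loss * lambda_aux).backward(retain_graph=True)

    # 2. Throw away the auxiliary-loss gradients accumulated in the decoder
    #    and joiner, so the auxiliary loss only updates the encoder.
    model.decoder.zero_grad()
    model.joiner.zero_grad()

    # 3. Backprop the main transducer loss; its gradients add to the encoder
    #    gradients left from step 1.
    transducer_loss.backward()

    clip_grad_norm_(model.parameters(), 5.0, 2.0)
    optimizer.step()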
@@ -56,6 +56,7 @@ from librispeech import LibriSpeech
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel.distributed import _find_tensors
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -611,7 +612,23 @@ def train_one_epoch(
         optimizer.zero_grad()
-        (transducer_loss + aux_loss * params.lambda_aux).backward()
+        if hasattr(model, "module"):
+            out_tensors = list(_find_tensors(aux_loss))
+            model.reducer.prepare_for_backward(out_tensors)
+            model2 = model.module
+        else:
+            model2 = model
+        (aux_loss * params.lambda_aux).backward(retain_graph=True)
+        # zero out the grad for decoder and joiner
+        model2.decoder.zero_grad()
+        model2.joiner.zero_grad()
+        if hasattr(model, "module"):
+            out_tensors = list(_find_tensors(transducer_loss))
+            model.reducer.prepare_for_backward(out_tensors)
+        transducer_loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
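Under DistributedDataParallel, gradient bucketing and all-reduce are normally armed for exactly one backward pass per forward pass. Because the step above now calls backward() twice on the same forward, the diff re-arms the reducer before each pass via the private helpers `_find_tensors()` and `Reducer.prepare_for_backward()`; the comment added in the next hunk points to pytorch/pytorch#47260 for the details. A hedged sketch of that pattern as a reusable helper, assuming a PyTorch version that still exposes these private APIs (the helper name `backward_maybe_ddp` is made up here):

from torch.nn.parallel.distributed import _find_tensors


def backward_maybe_ddp(model, loss, retain_graph=False):
    """Run loss.backward(), re-arming the DDP reducer first when model is wrapped."""
    if hasattr(model, "module"):  # model is a DistributedDataParallel wrapper
        # Tell the reducer which output tensors this backward starts from so
        # the bucketed all-reduce fires again for this extra backward pass.
        model.reducer.prepare_for_backward(list(_find_tensors(loss)))
    loss.backward(retain_graph=retain_graph)

With such a helper, the new step would call it once with `aux_loss * params.lambda_aux` and `retain_graph=True`, zero the decoder/joiner gradients on `model.module` (or on `model` when it is not wrapped), and then call it again with `transducer_loss`.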
@@ -888,7 +905,27 @@ def scan_pessimistic_batches_for_oom(
                 is_training=True,
             )
-            (transducer_loss + aux_loss * params.lambda_aux).backward()
+            libri = is_libri(batch["supervisions"]["cut"][0])
+            # see https://github.com/pytorch/pytorch/issues/47260#issuecomment-789127532  # noqa
+            # for details of `_find_tensors()` and `prepare_for_backward()`.
+            if hasattr(model, "module"):
+                out_tensors = list(_find_tensors(aux_loss))
+                model.reducer.prepare_for_backward(out_tensors)
+                model2 = model.module
+            else:
+                model2 = model
+            (aux_loss * params.lambda_aux).backward(retain_graph=True)
+            # zero out the grad for decoder and joiner
+            model2.decoder.zero_grad()
+            model2.joiner.zero_grad()
+            if hasattr(model, "module"):
+                out_tensors = list(_find_tensors(transducer_loss))
+                model.reducer.prepare_for_backward(out_tensors)
+            transducer_loss.backward()
             clip_grad_norm_(model.parameters(), 5.0, 2.0)
             optimizer.step()
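A toy, self-contained check of the ordering used above (auxiliary backward with `retain_graph=True`, zero the decoder/joiner gradients, then the main backward): the two linear layers are made-up stand-ins for the real encoder and decoder/joiner, and the losses are arbitrary; the check only illustrates that the decoder ends up with the main-loss gradient alone while the encoder accumulates both.

import torch
import torch.nn as nn

torch.manual_seed(0)

encoder = nn.Linear(4, 4)   # stand-in for the encoder
decoder = nn.Linear(4, 4)   # stand-in for the decoder/joiner

x = torch.randn(2, 4)
dec_out = decoder(encoder(x))

main_loss = dec_out.pow(2).mean()         # plays the role of transducer_loss
aux_loss = (dec_out - 1.0).abs().mean()   # plays the role of aux_loss

# Reference: gradient of the main loss alone w.r.t. the decoder weight.
(ref_grad,) = torch.autograd.grad(main_loss, decoder.weight, retain_graph=True)

# The pattern from the diff: aux backward, zero the decoder grads, main backward.
aux_loss.backward(retain_graph=True)
decoder.zero_grad()
main_loss.backward()

assert torch.allclose(decoder.weight.grad, ref_grad)  # decoder: main loss only
assert encoder.weight.grad is not None                # encoder: both losses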