From 4319a187b392d2a0fb6c211c96f7b30639eb3234 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 22 Feb 2022 19:44:01 +0800
Subject: [PATCH] Zero out the gradient of decoder/joiner for auxiliary losses.

---
 .../ASR/transducer_stateless_aux_kl/train.py | 41 ++++++++++++++++++-
 1 file changed, 39 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py b/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py
index 71bf8da86..0f7c5372a 100755
--- a/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py
+++ b/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py
@@ -56,6 +56,7 @@ from librispeech import LibriSpeech
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel.distributed import _find_tensors
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -611,7 +612,23 @@ def train_one_epoch(
 
         optimizer.zero_grad()
-        (transducer_loss + aux_loss * params.lambda_aux).backward()
+        if hasattr(model, "module"):
+            out_tensors = list(_find_tensors(aux_loss))
+            model.reducer.prepare_for_backward(out_tensors)
+            model2 = model.module
+        else:
+            model2 = model
+
+        (aux_loss * params.lambda_aux).backward(retain_graph=True)
+        # zero out the grad for decoder and joiner
+        model2.decoder.zero_grad()
+        model2.joiner.zero_grad()
+
+        if hasattr(model, "module"):
+            out_tensors = list(_find_tensors(transducer_loss))
+            model.reducer.prepare_for_backward(out_tensors)
+
+        transducer_loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
@@ -888,7 +905,27 @@ def scan_pessimistic_batches_for_oom(
                 is_training=True,
             )
-            (transducer_loss + aux_loss * params.lambda_aux).backward()
+            libri = is_libri(batch["supervisions"]["cut"][0])
+
+            # see https://github.com/pytorch/pytorch/issues/47260#issuecomment-789127532 # noqa
+            # for details of `_find_tensors()` and `prepare_for_backward()`.
+            if hasattr(model, "module"):
+                out_tensors = list(_find_tensors(aux_loss))
+                model.reducer.prepare_for_backward(out_tensors)
+                model2 = model.module
+            else:
+                model2 = model
+
+            (aux_loss * params.lambda_aux).backward(retain_graph=True)
+            # zero out the grad for decoder and joiner
+            model2.decoder.zero_grad()
+            model2.joiner.zero_grad()
+
+            if hasattr(model, "module"):
+                out_tensors = list(_find_tensors(transducer_loss))
+                model.reducer.prepare_for_backward(out_tensors)
+
+            transducer_loss.backward()
             clip_grad_norm_(model.parameters(), 5.0, 2.0)
             optimizer.step()
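
Note (not part of the commit): the following is a minimal, self-contained sketch of the gradient flow this patch implements, not the icefall recipe itself. The auxiliary loss is backpropagated first with retain_graph=True, the resulting gradients on the decoder and joiner are discarded, and only then is the transducer loss backpropagated, so decoder and joiner are trained by the transducer loss alone while the encoder accumulates gradients from both losses. The TinyTransducerLike class and the lambda_aux value below are made-up stand-ins. The `_find_tensors()` / `reducer.prepare_for_backward()` calls in the patch are private PyTorch APIs needed only under DistributedDataParallel, where each of the two backward() calls has to be announced to DDP's reducer (see the linked PyTorch issue); this single-process sketch omits them.

import torch
import torch.nn as nn


class TinyTransducerLike(nn.Module):
    """Toy stand-in with encoder/decoder/joiner submodules (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.decoder = nn.Linear(4, 4)
        self.joiner = nn.Linear(4, 1)

    def forward(self, x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        # the main loss depends on encoder, decoder and joiner
        transducer_loss = self.joiner(dec).pow(2).mean()
        # the auxiliary loss also touches decoder parameters in this toy setup
        aux_loss = dec.abs().mean()
        return transducer_loss, aux_loss


model = TinyTransducerLike()
lambda_aux = 0.3  # illustrative weight, stands in for params.lambda_aux

transducer_loss, aux_loss = model(torch.randn(8, 4))

# 1) Backprop the auxiliary loss first; retain the graph for the second pass.
(aux_loss * lambda_aux).backward(retain_graph=True)

# 2) Discard the auxiliary gradients on decoder and joiner, so those two
#    submodules end up being updated by the transducer loss only.
model.decoder.zero_grad()
model.joiner.zero_grad()

# 3) Backprop the main transducer loss; its gradients are added on top of
#    the encoder's auxiliary-loss gradients.
transducer_loss.backward()

# At this point the encoder holds gradients from both losses, while the
# decoder and joiner hold gradients from the transducer loss alone.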