Zero out the gradient of decoder/joiner for auxiliary losses.

Author: Fangjun Kuang
Date:   2022-02-22 19:44:01 +08:00
Parent: 76632bddfe
Commit: 4319a187b3


@@ -56,6 +56,7 @@ from librispeech import LibriSpeech
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.parallel.distributed import _find_tensors
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
@@ -611,7 +612,23 @@ def train_one_epoch(
optimizer.zero_grad()
(transducer_loss + aux_loss * params.lambda_aux).backward()
if hasattr(model, "module"):
    out_tensors = list(_find_tensors(aux_loss))
    model.reducer.prepare_for_backward(out_tensors)
    model2 = model.module
else:
    model2 = model
(aux_loss * params.lambda_aux).backward(retain_graph=True)
# zero out the grad for decoder and joiner
model2.decoder.zero_grad()
model2.joiner.zero_grad()
if hasattr(model, "module"):
    out_tensors = list(_find_tensors(transducer_loss))
    model.reducer.prepare_for_backward(out_tensors)
transducer_loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
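
For illustration, here is a minimal single-process sketch of the pattern used in this hunk: the auxiliary loss is backpropagated first with retain_graph=True, the decoder/joiner gradients it produced are discarded, and the transducer loss then fills in gradients for every submodule. ToyTransducer, its shapes and its losses are made-up stand-ins for the recipe's real Transducer model, not part of the commit.

import torch
import torch.nn as nn

# Hypothetical toy model; encoder/decoder/joiner mirror the names in the diff.
class ToyTransducer(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 8)
        self.joiner = nn.Linear(8, 1)

    def forward(self, x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        transducer_loss = self.joiner(dec).mean()
        aux_loss = dec.pow(2).mean()  # made-up auxiliary loss; flows back through decoder and encoder
        return transducer_loss, aux_loss

model = ToyTransducer()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lambda_aux = 0.3  # stand-in for params.lambda_aux

transducer_loss, aux_loss = model(torch.randn(4, 8))

optimizer.zero_grad()
# Backward the auxiliary loss first; retain_graph=True keeps the graph alive
# for the second backward below.
(aux_loss * lambda_aux).backward(retain_graph=True)
# Discard whatever gradient the auxiliary loss put on decoder/joiner, so it
# effectively trains only the encoder.
model.decoder.zero_grad()
model.joiner.zero_grad()
# The transducer loss then contributes gradients to all submodules.
transducer_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
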
@@ -888,7 +905,27 @@ def scan_pessimistic_batches_for_oom(
    is_training=True,
)
(transducer_loss + aux_loss * params.lambda_aux).backward()
libri = is_libri(batch["supervisions"]["cut"][0])
# see https://github.com/pytorch/pytorch/issues/47260#issuecomment-789127532 # noqa
# for details of `_find_tensors()` and `prepare_for_backward()`.
if hasattr(model, "module"):
    out_tensors = list(_find_tensors(aux_loss))
    model.reducer.prepare_for_backward(out_tensors)
    model2 = model.module
else:
    model2 = model
(aux_loss * params.lambda_aux).backward(retain_graph=True)
# zero out the grad for decoder and joiner
model2.decoder.zero_grad()
model2.joiner.zero_grad()
if hasattr(model, "module"):
    out_tensors = list(_find_tensors(transducer_loss))
    model.reducer.prepare_for_backward(out_tensors)
transducer_loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()