icefall (mirror of https://github.com/k2-fsa/icefall.git)

commit 4319a187b3 (parent 76632bddfe)

    Zero out the gradient of decoder/joiner for auxiliary losses.
@@ -56,6 +56,7 @@ from librispeech import LibriSpeech
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel.distributed import _find_tensors
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
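The only addition to the import block is `_find_tensors` from `torch.nn.parallel.distributed`. It is a private PyTorch helper that DDP itself uses to collect the output tensors a backward pass will start from; the two training-loop hunks below feed its result to `model.reducer.prepare_for_backward()` before each of the two manual backward calls.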
@@ -611,7 +612,23 @@ def train_one_epoch(

         optimizer.zero_grad()

-        (transducer_loss + aux_loss * params.lambda_aux).backward()
+        if hasattr(model, "module"):
+            out_tensors = list(_find_tensors(aux_loss))
+            model.reducer.prepare_for_backward(out_tensors)
+            model2 = model.module
+        else:
+            model2 = model
+
+        (aux_loss * params.lambda_aux).backward(retain_graph=True)
+        # zero out the grad for decoder and joiner
+        model2.decoder.zero_grad()
+        model2.joiner.zero_grad()
+
+        if hasattr(model, "module"):
+            out_tensors = list(_find_tensors(transducer_loss))
+            model.reducer.prepare_for_backward(out_tensors)
+
+        transducer_loss.backward()

         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
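Two details make the split backward above work, and the same block is repeated in `scan_pessimistic_batches_for_oom` below. First, `retain_graph=True` on the auxiliary backward keeps the shared autograd graph alive so that `transducer_loss.backward()` can traverse it again afterwards. Second, when the model is wrapped in DDP (`hasattr(model, "module")`), each manual backward is preceded by `model.reducer.prepare_for_backward(list(_find_tensors(loss)))` so that gradient synchronisation is set up for that particular pass; this is the workaround described in the PyTorch issue comment cited in the next hunk. Between the two passes, `model2.decoder.zero_grad()` and `model2.joiner.zero_grad()` discard whatever gradient the auxiliary loss put on the decoder and joiner, so the auxiliary loss effectively trains only the encoder.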
@@ -888,7 +905,27 @@ def scan_pessimistic_batches_for_oom(
                 is_training=True,
             )

-            (transducer_loss + aux_loss * params.lambda_aux).backward()
+            libri = is_libri(batch["supervisions"]["cut"][0])
+
+            # see https://github.com/pytorch/pytorch/issues/47260#issuecomment-789127532 # noqa
+            # for details of `_find_tensors()` and `prepare_for_backward()`.
+            if hasattr(model, "module"):
+                out_tensors = list(_find_tensors(aux_loss))
+                model.reducer.prepare_for_backward(out_tensors)
+                model2 = model.module
+            else:
+                model2 = model
+
+            (aux_loss * params.lambda_aux).backward(retain_graph=True)
+            # zero out the grad for decoder and joiner
+            model2.decoder.zero_grad()
+            model2.joiner.zero_grad()
+
+            if hasattr(model, "module"):
+                out_tensors = list(_find_tensors(transducer_loss))
+                model.reducer.prepare_for_backward(out_tensors)
+
+            transducer_loss.backward()

             clip_grad_norm_(model.parameters(), 5.0, 2.0)
             optimizer.step()
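For reference, here is a minimal single-process sketch of the same gradient-masking idea, without the DDP plumbing. The `TinyTransducer` module, the tensor shapes, both dummy loss expressions and the `lambda_aux` value are invented for illustration; only the ordering of the two backward passes, the `decoder`/`joiner` `zero_grad()` calls, the gradient clipping and the optimizer step mirror the commit.

# Minimal single-process sketch (no DDP, so no _find_tensors / prepare_for_backward).
# TinyTransducer, the shapes and both dummy losses are placeholders; only the
# backward ordering and the zero_grad() masking follow the commit.
import torch
import torch.nn as nn


class TinyTransducer(nn.Module):
    def __init__(self, feat_dim=8, vocab_size=5, hidden=16):
        super().__init__()
        self.encoder = nn.Linear(feat_dim, hidden)
        self.decoder = nn.Embedding(vocab_size, hidden)
        self.joiner = nn.Linear(hidden, vocab_size)

    def forward(self, feats, tokens):
        enc = self.encoder(feats)    # (N, T, H)
        dec = self.decoder(tokens)   # (N, U, H)
        # (N, T, U, V) joint output, as in a transducer
        joint = self.joiner(enc.unsqueeze(2) + dec.unsqueeze(1))
        transducer_loss = joint.logsumexp(dim=-1).mean()  # dummy main loss
        aux_loss = joint.pow(2).mean()  # dummy aux loss; also touches decoder/joiner
        return transducer_loss, aux_loss


model = TinyTransducer()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lambda_aux = 0.3  # counterpart of params.lambda_aux in the recipe

feats = torch.randn(2, 10, 8)
tokens = torch.randint(0, 5, (2, 4))

optimizer.zero_grad()
transducer_loss, aux_loss = model(feats, tokens)

# 1. Backprop the scaled auxiliary loss first; retain_graph=True keeps the
#    shared graph alive for the second backward pass.
(aux_loss * lambda_aux).backward(retain_graph=True)

# 2. Discard the gradients the aux loss left on decoder and joiner, so the
#    aux loss only ever updates the encoder.
model.decoder.zero_grad()
model.joiner.zero_grad()

# 3. Backprop the main loss; encoder gradients accumulate on top of the
#    aux-loss gradients, decoder/joiner gradients come from this loss alone.
transducer_loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()

Because the two losses share the encoder's forward graph, dropping `retain_graph=True` would make the second backward fail with a "trying to backward through the graph a second time" error.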