Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-11 02:52:18 +00:00)
Zero out the gradient of decoder/joiner for auxiliary losses.
This commit is contained in:
parent 76632bddfe
commit 4319a187b3
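The change below keeps the auxiliary loss from updating the decoder and joiner: the auxiliary loss is back-propagated first with retain_graph=True, the gradients it leaves on the decoder and joiner are zeroed, and only then is the transducer loss back-propagated, so its gradients accumulate on a clean slate for those two modules. A minimal single-process sketch of the idea, reusing the names model, transducer_loss, aux_loss, and params.lambda_aux from the diff (the DDP-specific handling shown in the hunks is omitted here):

    optimizer.zero_grad()

    # Auxiliary loss first; keep the graph alive so the transducer loss
    # can still back-propagate through the shared encoder afterwards.
    (aux_loss * params.lambda_aux).backward(retain_graph=True)

    # Discard the auxiliary-loss gradients on decoder and joiner, so the
    # auxiliary objective effectively trains only the encoder.
    model.decoder.zero_grad()
    model.joiner.zero_grad()

    # Main loss; backward() accumulates, so these gradients are added on
    # top of whatever the auxiliary loss left on the encoder.
    transducer_loss.backward()

    clip_grad_norm_(model.parameters(), 5.0, 2.0)
    optimizer.step()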
@@ -56,6 +56,7 @@ from librispeech import LibriSpeech
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.parallel.distributed import _find_tensors
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -611,7 +612,23 @@ def train_one_epoch(

         optimizer.zero_grad()

-        (transducer_loss + aux_loss * params.lambda_aux).backward()
+        if hasattr(model, "module"):
+            out_tensors = list(_find_tensors(aux_loss))
+            model.reducer.prepare_for_backward(out_tensors)
+            model2 = model.module
+        else:
+            model2 = model
+
+        (aux_loss * params.lambda_aux).backward(retain_graph=True)
+        # zero out the grad for decoder and joiner
+        model2.decoder.zero_grad()
+        model2.joiner.zero_grad()
+
+        if hasattr(model, "module"):
+            out_tensors = list(_find_tensors(transducer_loss))
+            model.reducer.prepare_for_backward(out_tensors)
+
+        transducer_loss.backward()

         clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
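Because this change calls backward() twice per batch, the DDP case re-arms the gradient reducer before each pass via model.reducer.prepare_for_backward(list(_find_tensors(loss))), following https://github.com/pytorch/pytorch/issues/47260#issuecomment-789127532 (cited in the next hunk). Note that _find_tensors and Reducer.prepare_for_backward are private PyTorch internals, so the code depends on implementation details of the installed torch version. A small helper sketching the pattern (the name ddp_backward is illustrative only, not part of the recipe):

    def ddp_backward(model, loss, **backward_kwargs):
        # Re-arm DDP's gradient reducer before an extra backward pass in
        # the same iteration; relies on private PyTorch APIs (see the
        # issue comment above).
        if hasattr(model, "module"):  # model is wrapped in DistributedDataParallel
            model.reducer.prepare_for_backward(list(_find_tensors(loss)))
        loss.backward(**backward_kwargs)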
@@ -888,7 +905,27 @@ def scan_pessimistic_batches_for_oom(
                 is_training=True,
             )

-            (transducer_loss + aux_loss * params.lambda_aux).backward()
+            libri = is_libri(batch["supervisions"]["cut"][0])
+
+            # see https://github.com/pytorch/pytorch/issues/47260#issuecomment-789127532 # noqa
+            # for details of `_find_tensors()` and `prepare_for_backward()`.
+            if hasattr(model, "module"):
+                out_tensors = list(_find_tensors(aux_loss))
+                model.reducer.prepare_for_backward(out_tensors)
+                model2 = model.module
+            else:
+                model2 = model
+
+            (aux_loss * params.lambda_aux).backward(retain_graph=True)
+            # zero out the grad for decoder and joiner
+            model2.decoder.zero_grad()
+            model2.joiner.zero_grad()
+
+            if hasattr(model, "module"):
+                out_tensors = list(_find_tensors(transducer_loss))
+                model.reducer.prepare_for_backward(out_tensors)
+
+            transducer_loss.backward()

             clip_grad_norm_(model.parameters(), 5.0, 2.0)
             optimizer.step()
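The retain_graph=True on the auxiliary backward is needed because both losses share part of the computation graph (the encoder output); without it, the first backward() frees the shared graph and the second backward() typically fails with "Trying to backward through the graph a second time". A self-contained toy illustration of the pattern in plain PyTorch (unrelated to the recipe code):

    import torch

    x = torch.randn(3, requires_grad=True)
    shared = x.tanh()                # shared sub-graph, analogous to the encoder output
    loss_aux = shared.sum()          # "auxiliary" loss
    loss_main = (shared ** 2).sum()  # "main" loss

    loss_aux.backward(retain_graph=True)  # keep the shared graph alive for the next pass
    loss_main.backward()                  # gradients accumulate into x.grad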