Disable gradient computation in evaluation mode.

Fangjun Kuang 2021-07-29 20:37:31 +08:00
parent acc63a9172
commit b94d97da37


@@ -13,21 +13,17 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-import torch.optim as optim
 from conformer import Conformer
-from transformer import Noam
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_value_
-from torch.optim.lr_scheduler import StepLR
 from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam

+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -194,7 +190,10 @@ def load_checkpoint_if_available(
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename, model=model, optimizer=optimizer, scheduler=scheduler,
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
     )

     keys = [
@@ -312,6 +311,7 @@ def compute_loss(
         )

     if params.att_rate != 0.0:
+        with torch.set_grad_enabled(is_training):
             att_loss = model.decoder_forward(
                 encoder_memory,
                 memory_mask,
@@ -431,7 +431,6 @@ def train_one_epoch(
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_value_(model.parameters(), 5.0)
         optimizer.step()

         loss_cpu = loss.detach().cpu().item()
@@ -575,7 +574,10 @@ def run(rank, world_size, args):
         )
         save_checkpoint(
-            params=params, model=model, optimizer=optimizer, rank=rank,
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
         )

     logging.info("Done!")