Disable gradient computation in evaluation mode.

Fangjun Kuang 2021-07-29 20:37:31 +08:00
parent acc63a9172
commit b94d97da37


@@ -13,21 +13,17 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from conformer import Conformer
from transformer import Noam
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
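
Among the imports above is `from transformer import Noam`, presumably the Noam warm-up learning-rate wrapper used with the conformer model. For reference, here is a minimal sketch of the Noam schedule from "Attention Is All You Need" that such a wrapper typically implements; the helper name `noam_lr` and the constants below are illustrative assumptions, not the actual interface of `transformer.py`:

```python
def noam_lr(step: int, model_size: int = 256, factor: float = 5.0, warmup: int = 80000) -> float:
    """Noam schedule: linear warm-up followed by inverse-square-root decay."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first call
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


# The rate rises during warm-up and decays afterwards.
for step in (1, 1000, 80000, 320000):
    print(step, round(noam_lr(step), 6))
```
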
@@ -194,7 +190,10 @@ def load_checkpoint_if_available(
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename, model=model, optimizer=optimizer, scheduler=scheduler,
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
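
The hunk above reflows the `load_checkpoint` call onto one keyword argument per line; the arguments themselves are unchanged. For context, resuming works by loading `exp_dir/epoch-{start_epoch-1}.pt` and restoring the model, optimizer, and scheduler state. A rough sketch of that pattern using plain `torch.load`; the helper name `resume_if_available` and the checkpoint keys are assumptions, not icefall's actual `load_checkpoint`:

```python
from pathlib import Path

import torch


def resume_if_available(exp_dir: Path, start_epoch: int, model, optimizer, scheduler=None):
    """Restore state from exp_dir/epoch-{start_epoch-1}.pt if present.

    Returns the remaining checkpoint metadata, or None if nothing was loaded.
    """
    if start_epoch <= 0:
        return None
    filename = exp_dir / f"epoch-{start_epoch - 1}.pt"
    if not filename.is_file():
        return None
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint.pop("model"))          # key names are assumed
    optimizer.load_state_dict(checkpoint.pop("optimizer"))
    if scheduler is not None and "scheduler" in checkpoint:
        scheduler.load_state_dict(checkpoint.pop("scheduler"))
    return checkpoint  # leftover entries such as loss statistics or epoch counters
```
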
@@ -312,13 +311,14 @@ def compute_loss(
)
if params.att_rate != 0.0:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
with torch.set_grad_enabled(is_training):
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
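
This hunk is the substance of the commit: the attention-decoder forward pass is now wrapped in `torch.set_grad_enabled(is_training)`, so when `compute_loss` is invoked from the validation loop with `is_training=False`, no autograd graph is built for `att_loss`; during training the wrapper has no effect. A self-contained illustration of the pattern, with a toy linear model standing in for the conformer:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)  # toy stand-in for the encoder/decoder
x = torch.randn(4, 10)


def compute_loss(is_training: bool) -> torch.Tensor:
    # Autograd records operations only when is_training is True.
    with torch.set_grad_enabled(is_training):
        loss = model(x).pow(2).mean()
    return loss


print(compute_loss(is_training=True).requires_grad)   # True  -> backward() is possible
print(compute_loss(is_training=False).requires_grad)  # False -> no graph kept for validation
```

When `params.att_rate != 0.0`, the final loss is still the weighted sum `(1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss`, exactly as before; only the gradient bookkeeping during evaluation changes.
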
@@ -431,7 +431,6 @@ def train_one_epoch(
optimizer.zero_grad()
loss.backward()
clip_grad_value_(model.parameters(), 5.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
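
In the update step shown above, gradients are clamped element-wise to `[-5.0, 5.0]` with `clip_grad_value_` before `optimizer.step()`, which guards against occasional exploding gradients. A minimal, self-contained sketch of the same pattern with a generic model and optimizer:

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
clip_grad_value_(model.parameters(), 5.0)  # clamp each gradient element to [-5, 5]
optimizer.step()
```
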
@@ -575,7 +574,10 @@ def run(rank, world_size, args):
)
save_checkpoint(
params=params, model=model, optimizer=optimizer, rank=rank,
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
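
`save_checkpoint` takes `rank` because under DistributedDataParallel every process executes this code, and normally only rank 0 should write the file so that processes do not race on the same path. A rough sketch of that guard using plain `torch.save`; the body below is an assumption about what `save_checkpoint_impl` does, not its actual implementation:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def save_checkpoint(filename: str, model: torch.nn.Module, optimizer, rank: int = 0) -> None:
    """Write the checkpoint from rank 0 only; other ranks return immediately."""
    if rank != 0:
        return
    # Unwrap DDP so the saved state dict has plain (non "module.") parameter names.
    inner = model.module if isinstance(model, DDP) else model
    torch.save(
        {"model": inner.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )
```
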