Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Disable gradient computation in evaluation mode.
commit b94d97da37
parent acc63a9172
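The diff below wraps the attention-decoder forward pass in `torch.set_grad_enabled(is_training)`, so that no autograd graph is built when the loss is computed in evaluation mode; the remaining hunks only reflow long call arguments and reorder imports. As a minimal sketch of the pattern (the function name `compute_loss_sketch`, the `features` argument, and the stand-in loss are illustrative, not the actual code from train.py):

import torch
import torch.nn as nn


def compute_loss_sketch(
    model: nn.Module, features: torch.Tensor, is_training: bool
) -> torch.Tensor:
    # Track gradients only while training; in evaluation mode the forward
    # pass runs without recording an autograd graph, saving memory and time.
    with torch.set_grad_enabled(is_training):
        output = model(features)
        loss = output.pow(2).mean()  # stand-in for the real CTC/attention loss
    return loss


if __name__ == "__main__":
    m = nn.Linear(4, 2)
    x = torch.randn(3, 4)
    print(compute_loss_sketch(m, x, is_training=False).requires_grad)  # False
    print(compute_loss_sketch(m, x, is_training=True).requires_grad)   # True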
@@ -13,21 +13,17 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
 from conformer import Conformer
-from transformer import Noam
-
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_value_
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
 
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -194,7 +190,10 @@ def load_checkpoint_if_available(
 
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename, model=model, optimizer=optimizer, scheduler=scheduler,
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
     )
 
     keys = [
@@ -312,13 +311,14 @@ def compute_loss(
     )
 
     if params.att_rate != 0.0:
-        att_loss = model.decoder_forward(
-            encoder_memory,
-            memory_mask,
-            token_ids=token_ids,
-            sos_id=graph_compiler.sos_id,
-            eos_id=graph_compiler.eos_id,
-        )
+        with torch.set_grad_enabled(is_training):
+            att_loss = model.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
     else:
         loss = ctc_loss
@@ -431,7 +431,6 @@ def train_one_epoch(
 
         optimizer.zero_grad()
         loss.backward()
         clip_grad_value_(model.parameters(), 5.0)
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
@@ -575,7 +574,10 @@ def run(rank, world_size, args):
         )
 
         save_checkpoint(
-            params=params, model=model, optimizer=optimizer, rank=rank,
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
         )
 
    logging.info("Done!")
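When is_training is False, `torch.set_grad_enabled(is_training)` behaves like `torch.no_grad()`: tensors produced inside the block have `requires_grad == False` and carry no graph, so validation batches do not pay the memory cost of backpropagation. A quick sanity check of that behaviour, using plain tensors rather than anything from this recipe:

import torch

x = torch.randn(2, 3, requires_grad=True)
with torch.set_grad_enabled(False):
    y = (x * 2).sum()
print(y.requires_grad)  # False: no autograd graph was recorded for y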
|
Loading…
x
Reference in New Issue
Block a user