Mirror of https://github.com/k2-fsa/icefall.git
Disable gradient computation in evaluation mode.

parent acc63a9172
commit b94d97da37
@@ -13,21 +13,17 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-import torch.optim as optim
 from conformer import Conformer
-from transformer import Noam
-
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_value_
-from torch.optim.lr_scheduler import StepLR
 from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
 
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -194,7 +190,10 @@ def load_checkpoint_if_available(
 
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename, model=model, optimizer=optimizer, scheduler=scheduler,
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
     )
 
     keys = [
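For readers unfamiliar with the icefall helpers, load_checkpoint and save_checkpoint follow the usual PyTorch convention of serializing and restoring state_dicts keyed by component. The following is a minimal, self-contained sketch of that round trip using plain torch.save/torch.load; it is an assumption about what the helpers roughly wrap, not the actual icefall implementation.

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)                       # toy stand-in for the Conformer
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Save: roughly the pattern the icefall helpers wrap (sketch, not the real code).
torch.save(
    {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": 3},
    "epoch-3.pt",
)

# Load: restore weights and optimizer state before resuming training.
ckpt = torch.load("epoch-3.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])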
@@ -312,6 +311,7 @@ def compute_loss(
     )
 
     if params.att_rate != 0.0:
+        with torch.set_grad_enabled(is_training):
             att_loss = model.decoder_forward(
                 encoder_memory,
                 memory_mask,
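The added with torch.set_grad_enabled(is_training): guard is the substance of this commit: the attention-decoder forward pass only builds an autograd graph when compute_loss runs with is_training=True, so validation passes neither track gradients nor retain activations. A minimal sketch of the pattern follows; the run_decoder helper and tensor shapes are illustrative, not taken from the icefall source.

import torch
import torch.nn as nn

decoder = nn.Linear(16, 4)  # stand-in for the attention decoder

def run_decoder(memory: torch.Tensor, is_training: bool) -> torch.Tensor:
    # Gradients are tracked only when is_training is True, mirroring the
    # guard added in compute_loss above.
    with torch.set_grad_enabled(is_training):
        return decoder(memory)

memory = torch.randn(8, 16)
print(run_decoder(memory, is_training=True).requires_grad)    # True
print(run_decoder(memory, is_training=False).requires_grad)   # False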
@@ -431,7 +431,6 @@ def train_one_epoch(
 
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_value_(model.parameters(), 5.0)
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
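The deleted call used torch.nn.utils.clip_grad_value_, which clamps every parameter gradient element-wise to [-clip_value, clip_value] in place; after this commit the optimizer steps on unclipped gradients. A short illustration of what the removed line did, on a toy model rather than the Conformer:

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_

model = nn.Linear(4, 2)                       # toy stand-in for the Conformer
loss = model(torch.randn(3, 4)).pow(2).sum()
loss.backward()

# The removed line clamped each gradient element to [-5.0, 5.0] in place
# between loss.backward() and optimizer.step().
clip_grad_value_(model.parameters(), 5.0)
print(max(p.grad.abs().max().item() for p in model.parameters()))  # <= 5.0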
@@ -575,7 +574,10 @@ def run(rank, world_size, args):
         )
 
         save_checkpoint(
-            params=params, model=model, optimizer=optimizer, rank=rank,
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
         )
 
     logging.info("Done!")
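The rank=rank argument to save_checkpoint suggests the usual DDP convention in which only rank 0 writes the checkpoint file so that workers do not clobber each other. That behaviour is assumed here, not quoted from icefall; a hypothetical sketch of such a rank-gated save:

import torch

def save_checkpoint_sketch(filename, model, optimizer, rank: int = 0) -> None:
    # Assumed convention: only the rank-0 process writes to disk under DDP.
    if rank != 0:
        return
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )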