mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
fix usages of returned losses after adding attention-decoder in zipformer (#1689)
This commit is contained in:
parent
f6febd658e
commit
334beed2af
@ -758,7 +758,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, _ = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -766,6 +766,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss = losses[:2]
|
||||||
|
|
||||||
s = params.simple_loss_scale
|
s = params.simple_loss_scale
|
||||||
# take down the scale on the simple loss from 1.0 at the start
|
# take down the scale on the simple loss from 1.0 at the start
|
||||||
|
@ -343,7 +343,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, _ = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -351,6 +351,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss = losses[:2]
|
||||||
|
|
||||||
s = params.simple_loss_scale
|
s = params.simple_loss_scale
|
||||||
# take down the scale on the simple loss from 1.0 at the start
|
# take down the scale on the simple loss from 1.0 at the start
|
||||||
|
@ -814,7 +814,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -822,6 +822,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -449,7 +449,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -457,6 +457,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -803,7 +803,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -811,6 +811,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -806,7 +806,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -814,6 +814,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -787,7 +787,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -795,6 +795,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -55,7 +55,6 @@ It supports training with:
|
|||||||
import argparse
|
import argparse
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import random
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
@ -804,7 +803,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -812,6 +811,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -893,7 +893,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -901,6 +901,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -890,7 +890,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -898,6 +898,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -903,7 +903,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -911,6 +911,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -792,7 +792,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -800,6 +800,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -754,7 +754,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, _ = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -762,6 +762,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss = losses[:2]
|
||||||
|
|
||||||
s = params.simple_loss_scale
|
s = params.simple_loss_scale
|
||||||
# take down the scale on the simple loss from 1.0 at the start
|
# take down the scale on the simple loss from 1.0 at the start
|
||||||
|
@ -832,7 +832,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -840,6 +840,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -814,7 +814,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -822,6 +822,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import optim
|
import optim
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -791,7 +790,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -799,6 +798,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -67,7 +67,6 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import SPGISpeechAsrDataModule
|
from asr_datamodule import SPGISpeechAsrDataModule
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import AsrModel
|
from model import AsrModel
|
||||||
@ -792,7 +791,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -800,6 +799,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
|
@ -758,7 +758,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, _ = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -766,6 +766,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss = losses[:2]
|
||||||
|
|
||||||
s = params.simple_loss_scale
|
s = params.simple_loss_scale
|
||||||
# take down the scale on the simple loss from 1.0 at the start
|
# take down the scale on the simple loss from 1.0 at the start
|
||||||
|
@ -70,8 +70,7 @@ import copy
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from typing import List, Optional, Tuple, Union
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import optim
|
import optim
|
||||||
@ -80,7 +79,6 @@ import torch.multiprocessing as mp
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import WenetSpeechAsrDataModule
|
from asr_datamodule import WenetSpeechAsrDataModule
|
||||||
from lhotse.cut import Cut, CutSet
|
from lhotse.cut import Cut, CutSet
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from optim import Eden, ScaledAdam
|
from optim import Eden, ScaledAdam
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@ -103,14 +101,13 @@ from train import (
|
|||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
from icefall.checkpoint import remove_checkpoints
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
save_checkpoint_with_global_batch_idx,
|
save_checkpoint_with_global_batch_idx,
|
||||||
update_averaged_model,
|
update_averaged_model,
|
||||||
)
|
)
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.err import raise_grad_scale_is_too_small_error
|
from icefall.err import raise_grad_scale_is_too_small_error
|
||||||
from icefall.hooks import register_inf_check_hooks
|
from icefall.hooks import register_inf_check_hooks
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
@ -296,7 +293,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -304,6 +301,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
|
|
||||||
@ -344,40 +342,6 @@ def compute_loss(
|
|||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
|
|
||||||
def compute_validation_loss(
|
|
||||||
params: AttributeDict,
|
|
||||||
model: Union[nn.Module, DDP],
|
|
||||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
|
||||||
world_size: int = 1,
|
|
||||||
) -> MetricsTracker:
|
|
||||||
"""Run the validation process."""
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
tot_loss = MetricsTracker()
|
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
|
||||||
loss, loss_info = compute_loss(
|
|
||||||
params=params,
|
|
||||||
model=model,
|
|
||||||
graph_compiler=graph_compiler,
|
|
||||||
batch=batch,
|
|
||||||
is_training=False,
|
|
||||||
)
|
|
||||||
assert loss.requires_grad is False
|
|
||||||
tot_loss = tot_loss + loss_info
|
|
||||||
|
|
||||||
if world_size > 1:
|
|
||||||
tot_loss.reduce(loss.device)
|
|
||||||
|
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
|
||||||
if loss_value < params.best_valid_loss:
|
|
||||||
params.best_valid_epoch = params.cur_epoch
|
|
||||||
params.best_valid_loss = loss_value
|
|
||||||
|
|
||||||
return tot_loss
|
|
||||||
|
|
||||||
|
|
||||||
def train_one_epoch(
|
def train_one_epoch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
|
@ -815,7 +815,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, _ = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -823,6 +823,7 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
|
simple_loss, pruned_loss = losses[:2]
|
||||||
|
|
||||||
s = params.simple_loss_scale
|
s = params.simple_loss_scale
|
||||||
# take down the scale on the simple loss from 1.0 at the start
|
# take down the scale on the simple loss from 1.0 at the start
|
||||||
|
Loading…
x
Reference in New Issue
Block a user