Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00
fix usages of returned losses after adding attention-decoder in zipformer (#1689)
parent: f6febd658e
commit: 334beed2af
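
Every hunk below makes the same fix: after the attention decoder was added, the model's forward() returns more loss terms than before, so call sites that unpacked a fixed-size tuple (simple_loss, pruned_loss, ctc_loss = model(...)) would raise "too many values to unpack". The fix binds the whole return value and slices off the prefix each recipe needs. A minimal self-contained sketch of the failure mode and the fix — the toy forward() below is an assumption standing in for the real AsrModel.forward, whose exact return signature is not shown in this diff:

import torch

# Toy stand-in (assumption) for AsrModel.forward: it used to return three
# losses and now returns a fourth one for the attention decoder.
def forward():
    simple_loss = torch.tensor(1.0)
    pruned_loss = torch.tensor(2.0)
    ctc_loss = torch.tensor(0.5)
    attention_decoder_loss = torch.tensor(0.3)  # newly added term (assumption)
    return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss

# Old call site: fixed-arity unpacking breaks once a fourth value appears.
try:
    simple_loss, pruned_loss, ctc_loss = forward()
except ValueError as e:
    print(e)  # too many values to unpack (expected 3)

# Fixed call site: bind the tuple, then slice off what this recipe uses.
losses = forward()
simple_loss, pruned_loss, ctc_loss = losses[:3]  # recipes with a CTC head
simple_loss, pruned_loss = losses[:2]            # pure transducer recipes
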
@@ -758,7 +758,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -766,6 +766,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start

@@ -343,7 +343,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -351,6 +351,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start

@@ -814,7 +814,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -822,6 +822,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -449,7 +449,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -457,6 +457,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -803,7 +803,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -811,6 +811,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -806,7 +806,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -814,6 +814,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -787,7 +787,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -795,6 +795,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -55,7 +55,6 @@ It supports training with:
 import argparse
 import copy
 import logging
 import random
 import warnings
 from pathlib import Path
 from shutil import copyfile

@@ -804,7 +803,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -812,6 +811,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -893,7 +893,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -901,6 +901,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -890,7 +890,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -898,6 +898,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -903,7 +903,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -911,6 +911,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -792,7 +792,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -800,6 +800,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -754,7 +754,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -762,6 +762,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start

@@ -832,7 +832,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -840,6 +840,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -814,7 +814,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -822,6 +822,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
 import optim
 import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn

@@ -791,7 +790,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -799,6 +798,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -67,7 +67,6 @@ import torch.nn as nn
 from asr_datamodule import SPGISpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel

@@ -792,7 +791,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -800,6 +799,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -758,7 +758,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -766,6 +766,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start

@@ -70,8 +70,7 @@ import copy
 import logging
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import k2
 import optim

@@ -80,7 +79,6 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import WenetSpeechAsrDataModule
 from lhotse.cut import Cut, CutSet
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from optim import Eden, ScaledAdam
 from torch import Tensor

@@ -103,14 +101,13 @@ from train import (
 
 from icefall import diagnostics
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon

@@ -296,7 +293,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -304,6 +301,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0

@@ -344,40 +342,6 @@ def compute_loss(
     return loss, info
 
 
-def compute_validation_loss(
-    params: AttributeDict,
-    model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
-    valid_dl: torch.utils.data.DataLoader,
-    world_size: int = 1,
-) -> MetricsTracker:
-    """Run the validation process."""
-    model.eval()
-
-    tot_loss = MetricsTracker()
-
-    for batch_idx, batch in enumerate(valid_dl):
-        loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            graph_compiler=graph_compiler,
-            batch=batch,
-            is_training=False,
-        )
-        assert loss.requires_grad is False
-        tot_loss = tot_loss + loss_info
-
-        if world_size > 1:
-            tot_loss.reduce(loss.device)
-
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
-    if loss_value < params.best_valid_loss:
-        params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = loss_value
-
-    return tot_loss
-
-
 def train_one_epoch(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
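
The hunk above is the one place where this commit does more than re-slice losses: it deletes the recipe-local copy of compute_validation_loss outright. Given the `from train import (` context line earlier in the same file, the recipe presumably falls back to the shared implementation, so the loss-unpacking fix only has to live in one place. A runnable toy sketch of the rationale — why a stale local duplicate is dangerous (the function bodies here are illustrative stand-ins, not the real helpers):

# Python binds a name to its latest definition in the file, so a recipe-local
# duplicate defined after an import silently shadows the shared, already-fixed
# implementation.
def compute_validation_loss():  # stands in for the shared helper (assumption)
    return "shared, fixed version"

def compute_validation_loss():  # recipe-local duplicate, later in the file
    return "stale local copy"

print(compute_validation_loss())  # -> "stale local copy"
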
@@ -815,7 +815,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -823,6 +823,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start