fix usages of returned losses after adding attention-decoder in zipformer (#1689)

Zengwei Yao 2024-07-12 16:50:58 +08:00 committed by GitHub
parent f6febd658e
commit 334beed2af
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
20 changed files with 42 additions and 62 deletions
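
Context for the repeated change below: after the attention-decoder head was added to the Zipformer recipe, the model's forward call returns a longer tuple of losses (e.g. with an attention-decoder loss appended), so the old pattern of unpacking exactly two or three values raises a ValueError. Every compute_loss below therefore keeps the full tuple and slices off the terms it needs. The following is a minimal sketch of that pattern, not icefall's actual model code; DummyAsrModel and its hard-coded loss values are hypothetical stand-ins that only mimic the return shape.

from typing import Tuple

import torch


class DummyAsrModel(torch.nn.Module):
    """Hypothetical stand-in for the recipe's AsrModel; only the return shape matters."""

    def forward(self, x, x_lens, y) -> Tuple[torch.Tensor, ...]:
        # Dummy values standing in for the real loss computation.
        simple_loss = torch.tensor(1.0)
        pruned_loss = torch.tensor(2.0)
        ctc_loss = torch.tensor(0.5)
        attn_decoder_loss = torch.tensor(0.7)  # the newly appended term
        return simple_loss, pruned_loss, ctc_loss, attn_decoder_loss


model = DummyAsrModel()

# Old style: `simple_loss, pruned_loss, _ = model(...)` breaks once the tuple grows.
# New style: keep the whole tuple and slice out what the recipe actually uses.
losses = model(x=None, x_lens=None, y=None)
simple_loss, pruned_loss = losses[:2]            # transducer-only recipes
simple_loss, pruned_loss, ctc_loss = losses[:3]  # recipes that also train a CTC head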

View File

@@ -758,7 +758,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -766,6 +766,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start

View File

@@ -343,7 +343,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -351,6 +351,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start

View File

@@ -814,7 +814,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -822,6 +822,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -449,7 +449,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -457,6 +457,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -803,7 +803,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -811,6 +811,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -806,7 +806,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -814,6 +814,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -787,7 +787,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -795,6 +795,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -55,7 +55,6 @@ It supports training with:
 import argparse
 import copy
 import logging
-import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -804,7 +803,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -812,6 +811,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -893,7 +893,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -901,6 +901,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -890,7 +890,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -898,6 +898,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -903,7 +903,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -911,6 +911,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -792,7 +792,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -800,6 +800,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -754,7 +754,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -762,6 +762,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start

View File

@@ -832,7 +832,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -840,6 +840,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -814,7 +814,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -822,6 +822,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union
 import k2
 import optim
-import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -791,7 +790,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -799,6 +798,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -67,7 +67,6 @@ import torch.nn as nn
 from asr_datamodule import SPGISpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -792,7 +791,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -800,6 +799,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0

View File

@@ -758,7 +758,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -766,6 +766,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start

View File

@@ -70,8 +70,7 @@ import copy
 import logging
 import warnings
 from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 import k2
 import optim
@@ -80,7 +79,6 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import WenetSpeechAsrDataModule
 from lhotse.cut import Cut, CutSet
-from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from optim import Eden, ScaledAdam
 from torch import Tensor
@@ -103,14 +101,13 @@ from train import (
 from icefall import diagnostics
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
@@ -296,7 +293,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -304,6 +301,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
         loss = 0.0
@@ -344,40 +342,6 @@ def compute_loss(
     return loss, info
-def compute_validation_loss(
-    params: AttributeDict,
-    model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
-    valid_dl: torch.utils.data.DataLoader,
-    world_size: int = 1,
-) -> MetricsTracker:
-    """Run the validation process."""
-    model.eval()
-    tot_loss = MetricsTracker()
-    for batch_idx, batch in enumerate(valid_dl):
-        loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            graph_compiler=graph_compiler,
-            batch=batch,
-            is_training=False,
-        )
-        assert loss.requires_grad is False
-        tot_loss = tot_loss + loss_info
-    if world_size > 1:
-        tot_loss.reduce(loss.device)
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
-    if loss_value < params.best_valid_loss:
-        params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = loss_value
-    return tot_loss
 def train_one_epoch(
     params: AttributeDict,
     model: Union[nn.Module, DDP],

View File

@@ -815,7 +815,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -823,6 +823,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start