fix usages of returned losses after adding attention-decoder in zipformer (#1689)

This commit is contained in:
Zengwei Yao 2024-07-12 16:50:58 +08:00 committed by GitHub
parent f6febd658e
commit 334beed2af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 42 additions and 62 deletions

View File

@ -758,7 +758,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -766,6 +766,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start

View File

@ -343,7 +343,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -351,6 +351,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start

View File

@ -814,7 +814,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -822,6 +822,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -449,7 +449,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -457,6 +457,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -803,7 +803,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -811,6 +811,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -806,7 +806,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -814,6 +814,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -787,7 +787,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -795,6 +795,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -55,7 +55,6 @@ It supports training with:
import argparse
import copy
import logging
import random
import warnings
from pathlib import Path
from shutil import copyfile
@ -804,7 +803,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -812,6 +811,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -893,7 +893,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -901,6 +901,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -890,7 +890,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -898,6 +898,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -903,7 +903,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -911,6 +911,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -792,7 +792,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -800,6 +800,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -754,7 +754,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -762,6 +762,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start

View File

@ -832,7 +832,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -840,6 +840,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -814,7 +814,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -822,6 +822,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
@ -791,7 +790,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -799,6 +798,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -67,7 +67,6 @@ import torch.nn as nn
from asr_datamodule import SPGISpeechAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
@ -792,7 +791,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -800,6 +799,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0

View File

@ -758,7 +758,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -766,6 +766,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start

View File

@ -70,8 +70,7 @@ import copy
import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import k2
import optim
@ -80,7 +79,6 @@ import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut, CutSet
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam
from torch import Tensor
@ -103,14 +101,13 @@ from train import (
from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx,
update_averaged_model,
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
@ -296,7 +293,7 @@ def compute_loss(
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -304,6 +301,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0
@ -344,40 +342,6 @@ def compute_loss(
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
graph_compiler: CharCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
batch=batch,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],

View File

@ -815,7 +815,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model(
losses = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -823,6 +823,7 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start