Merge branch 'k2-fsa:master' into dev/zipformer_lstm

This commit is contained in:
Yifan Yang 2024-07-14 00:32:21 +08:00 committed by GitHub
commit 359ffce6c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 56 additions and 75 deletions

View File

@ -758,7 +758,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -766,6 +766,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start # take down the scale on the simple loss from 1.0 at the start

View File

@ -343,7 +343,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -351,6 +351,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start # take down the scale on the simple loss from 1.0 at the start

View File

@ -814,7 +814,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -822,6 +822,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -449,7 +449,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -457,6 +457,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -803,7 +803,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -811,6 +811,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -90,7 +90,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
--encoder-dim 128,128,128,128,128,128 \ --encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 --encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \ python ./zipformer/export-onnx-streaming.py \
--exp-dir zipformer/exp \ --exp-dir zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 12 \ --epoch 12 \
@ -184,7 +184,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--encoder-dim 128,128,128,128,128,128 \ --encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 --encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \ python ./zipformer/export-onnx-streaming.py \
--exp-dir zipformer/exp_finetune \ --exp-dir zipformer/exp_finetune \
--tokens data/lang_bpe_500/tokens.txt \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 10 \ --epoch 10 \

View File

@ -806,7 +806,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -814,6 +814,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -787,7 +787,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -795,6 +795,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -55,7 +55,6 @@ It supports training with:
import argparse import argparse
import copy import copy
import logging import logging
import random
import warnings import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
@ -804,7 +803,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -812,6 +811,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -893,7 +893,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -901,6 +901,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -636,8 +636,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
) )
def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
""" """Forward function.
Forward function. Args:
Args:
x: a Tensor of shape (batch_size, channels, seq_len) x: a Tensor of shape (batch_size, channels, seq_len)
chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
""" """
@ -1032,7 +1033,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
w.prob = w.max_prob w.prob = w.max_prob
metric.backward() metric.backward()
penalty_grad = x_detached.grad penalty_grad = x_detached.grad
scale = w.grad_scale * ( scale = float(w.grad_scale) * (
x_grad.to(torch.float32).norm() x_grad.to(torch.float32).norm()
/ (penalty_grad.norm() + 1.0e-20) / (penalty_grad.norm() + 1.0e-20)
) )
@ -1074,7 +1075,7 @@ class Whiten(nn.Module):
super(Whiten, self).__init__() super(Whiten, self).__init__()
assert num_groups >= 1 assert num_groups >= 1
assert float(whitening_limit) >= 1 assert float(whitening_limit) >= 1
assert grad_scale >= 0 assert float(grad_scale) >= 0
self.num_groups = num_groups self.num_groups = num_groups
self.whitening_limit = whitening_limit self.whitening_limit = whitening_limit
self.grad_scale = grad_scale self.grad_scale = grad_scale

View File

@ -406,7 +406,7 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -429,7 +429,7 @@ def get_parser():
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" "part.", help="The scale to smooth the loss with am (output of encoder network) part.",
) )
parser.add_argument( parser.add_argument(
@ -848,7 +848,7 @@ def compute_loss(
True for training. False for validation. When it is True, this True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it function enables autograd during computation; when it is False, it
disables autograd. disables autograd.
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = model.device if isinstance(model, DDP) else next(model.parameters()).device device = model.device if isinstance(model, DDP) else next(model.parameters()).device

View File

@ -890,7 +890,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -898,6 +898,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -903,7 +903,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -911,6 +911,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -1137,7 +1137,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
w.prob = w.max_prob w.prob = w.max_prob
metric.backward() metric.backward()
penalty_grad = x_detached.grad penalty_grad = x_detached.grad
scale = w.grad_scale * ( scale = float(w.grad_scale) * (
x_grad.to(torch.float32).norm() x_grad.to(torch.float32).norm()
/ (penalty_grad.norm() + 1.0e-20) / (penalty_grad.norm() + 1.0e-20)
) )
@ -1179,7 +1179,7 @@ class Whiten(nn.Module):
super(Whiten, self).__init__() super(Whiten, self).__init__()
assert num_groups >= 1 assert num_groups >= 1
assert float(whitening_limit) >= 1 assert float(whitening_limit) >= 1
assert grad_scale >= 0 assert float(grad_scale) >= 0
self.num_groups = num_groups self.num_groups = num_groups
self.whitening_limit = whitening_limit self.whitening_limit = whitening_limit
self.grad_scale = grad_scale self.grad_scale = grad_scale

View File

@ -792,7 +792,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -800,6 +800,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -754,7 +754,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -762,6 +762,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start # take down the scale on the simple loss from 1.0 at the start

View File

@ -832,7 +832,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -840,6 +840,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -814,7 +814,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -822,6 +822,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union
import k2 import k2
import optim import optim
import sentencepiece as spm
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
@ -791,7 +790,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -799,6 +798,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -67,7 +67,6 @@ import torch.nn as nn
from asr_datamodule import SPGISpeechAsrDataModule from asr_datamodule import SPGISpeechAsrDataModule
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import AsrModel from model import AsrModel
@ -792,7 +791,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -800,6 +799,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0

View File

@ -758,7 +758,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -766,6 +766,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start # take down the scale on the simple loss from 1.0 at the start

View File

@ -91,7 +91,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
--encoder-dim 128,128,128,128,128,128 \ --encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 --encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \ python ./zipformer/export-onnx-streaming.py \
--exp-dir zipformer/exp \ --exp-dir zipformer/exp \
--tokens data/lang_partial_tone/tokens.txt \ --tokens data/lang_partial_tone/tokens.txt \
--epoch 18 \ --epoch 18 \
@ -187,7 +187,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--encoder-dim 128,128,128,128,128,128 \ --encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 --encoder-unmasked-dim 128,128,128,128,128,128
python ./zipformer/export_onnx_streaming.py \ python ./zipformer/export-onnx-streaming.py \
--exp-dir zipformer/exp_finetune \ --exp-dir zipformer/exp_finetune \
--tokens data/lang_partial_tone/tokens.txt \ --tokens data/lang_partial_tone/tokens.txt \
--epoch 10 \ --epoch 10 \

View File

@ -70,8 +70,7 @@ import copy
import logging import logging
import warnings import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import k2 import k2
import optim import optim
@ -80,7 +79,6 @@ import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule from asr_datamodule import WenetSpeechAsrDataModule
from lhotse.cut import Cut, CutSet from lhotse.cut import Cut, CutSet
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
@ -103,14 +101,13 @@ from train import (
from icefall import diagnostics from icefall import diagnostics
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import ( from icefall.checkpoint import (
save_checkpoint_with_global_batch_idx, save_checkpoint_with_global_batch_idx,
update_averaged_model, update_averaged_model,
) )
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
@ -296,7 +293,7 @@ def compute_loss(
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -304,6 +301,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss, ctc_loss = losses[:3]
loss = 0.0 loss = 0.0
@ -344,40 +342,6 @@ def compute_loss(
return loss, info return loss, info
def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    graph_compiler: CharCtcTrainingGraphCompiler,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Evaluate the model on the validation set and track the best loss.

    The model is put into eval mode, ``compute_loss`` is called on every
    batch of ``valid_dl`` with ``is_training=False``, and the per-batch
    statistics are summed into a single ``MetricsTracker``.

    Args:
      params:
        Training state. ``best_valid_loss`` and ``cur_epoch`` are read;
        ``best_valid_loss`` and ``best_valid_epoch`` are overwritten when
        the current per-frame loss is lower than the stored best.
      model:
        The model to evaluate.
      graph_compiler:
        Forwarded unchanged to ``compute_loss``.
      valid_dl:
        Iterable yielding validation batches.
      world_size:
        When greater than 1, the accumulated statistics are reduced via
        ``MetricsTracker.reduce`` (presumably across workers -- confirm
        against the ``MetricsTracker`` implementation).

    Returns:
      The accumulated (and, if applicable, reduced) statistics.

    Note:
      The reduction uses the device of the last batch's loss, so at least
      one batch must have been processed when ``world_size > 1``.
    """
    model.eval()

    accumulated = MetricsTracker()

    for batch in valid_dl:
        loss, batch_info = compute_loss(
            params=params,
            model=model,
            graph_compiler=graph_compiler,
            batch=batch,
            is_training=False,
        )
        # Validation must never build an autograd graph.
        assert loss.requires_grad is False
        accumulated = accumulated + batch_info

    if world_size > 1:
        accumulated.reduce(loss.device)

    # Per-frame loss: the quantity compared against the best seen so far.
    per_frame_loss = accumulated["loss"] / accumulated["frames"]
    if per_frame_loss < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = per_frame_loss

    return accumulated
def train_one_epoch( def train_one_epoch(
params: AttributeDict, params: AttributeDict,
model: Union[nn.Module, DDP], model: Union[nn.Module, DDP],

View File

@ -815,7 +815,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, _ = model( losses = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -823,6 +823,7 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
) )
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start # take down the scale on the simple loss from 1.0 at the start