diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py
index a25979226..cd253c597 100755
--- a/egs/aishell/ASR/zipformer/train.py
+++ b/egs/aishell/ASR/zipformer/train.py
@@ -758,7 +758,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -766,6 +766,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py
index 0713c5787..46a5506db 100755
--- a/egs/aishell/ASR/zipformer/train_bbpe.py
+++ b/egs/aishell/ASR/zipformer/train_bbpe.py
@@ -343,7 +343,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -351,6 +351,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py
index 5cda9bfd4..271014db0 100755
--- a/egs/commonvoice/ASR/zipformer/train.py
+++ b/egs/commonvoice/ASR/zipformer/train.py
@@ -814,7 +814,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -822,6 +822,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py
index a780bbbbc..0aa7856cc 100755
--- a/egs/commonvoice/ASR/zipformer/train_char.py
+++ b/egs/commonvoice/ASR/zipformer/train_char.py
@@ -449,7 +449,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -457,6 +457,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py
index f0ad98147..4c122effe 100755
--- a/egs/gigaspeech/ASR/zipformer/train.py
+++ b/egs/gigaspeech/ASR/zipformer/train.py
@@ -803,7 +803,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -811,6 +811,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py
index a4d670169..39d8fc6cd 100755
--- a/egs/gigaspeech/KWS/zipformer/train.py
+++ b/egs/gigaspeech/KWS/zipformer/train.py
@@ -806,7 +806,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -814,6 +814,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py
index b612b6835..485ea69c9 100755
--- a/egs/ksponspeech/ASR/zipformer/train.py
+++ b/egs/ksponspeech/ASR/zipformer/train.py
@@ -787,7 +787,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -795,6 +795,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py
index 8d4d9d067..357e8a827 100644
--- a/egs/libriheavy/ASR/zipformer/train.py
+++ b/egs/libriheavy/ASR/zipformer/train.py
@@ -55,7 +55,6 @@ It supports training with:
 import argparse
 import copy
 import logging
-import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -804,7 +803,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -812,6 +811,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py
index 2f7ec0c17..2ff631914 100755
--- a/egs/librispeech/ASR/zipformer/finetune.py
+++ b/egs/librispeech/ASR/zipformer/finetune.py
@@ -893,7 +893,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -901,6 +901,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py
index 6c55896a8..3511590da 100755
--- a/egs/librispeech/ASR/zipformer_adapter/train.py
+++ b/egs/librispeech/ASR/zipformer_adapter/train.py
@@ -890,7 +890,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -898,6 +898,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py
index 0464cf65c..3f36f229f 100755
--- a/egs/librispeech/ASR/zipformer_lora/finetune.py
+++ b/egs/librispeech/ASR/zipformer_lora/finetune.py
@@ -903,7 +903,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -911,6 +911,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py
index 3ccf7d2f1..9ab214e86 100755
--- a/egs/librispeech/ASR/zipformer_lora/train.py
+++ b/egs/librispeech/ASR/zipformer_lora/train.py
@@ -792,7 +792,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -800,6 +800,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py
index 2fae66844..730db7718 100755
--- a/egs/mdcc/ASR/zipformer/train.py
+++ b/egs/mdcc/ASR/zipformer/train.py
@@ -754,7 +754,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -762,6 +762,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py
index 1fc4c35c1..3dbfc48eb 100755
--- a/egs/multi_zh-hans/ASR/zipformer/train.py
+++ b/egs/multi_zh-hans/ASR/zipformer/train.py
@@ -832,7 +832,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -840,6 +840,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py
index 5dba584f7..04bb41214 100755
--- a/egs/multi_zh_en/ASR/zipformer/train.py
+++ b/egs/multi_zh_en/ASR/zipformer/train.py
@@ -814,7 +814,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -822,6 +822,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/reazonspeech/ASR/zipformer/train.py b/egs/reazonspeech/ASR/zipformer/train.py
index 8c6f4bb9a..30bd3efba 100755
--- a/egs/reazonspeech/ASR/zipformer/train.py
+++ b/egs/reazonspeech/ASR/zipformer/train.py
@@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
 import optim
-import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -791,7 +790,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -799,6 +798,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/spgispeech/ASR/zipformer/train.py b/egs/spgispeech/ASR/zipformer/train.py
index ed66ca29b..dfc21c968 100755
--- a/egs/spgispeech/ASR/zipformer/train.py
+++ b/egs/spgispeech/ASR/zipformer/train.py
@@ -67,7 +67,6 @@ import torch.nn as nn
 from asr_datamodule import SPGISpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -792,7 +791,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -800,6 +799,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py
index 3d3762916..25b16f632 100755
--- a/egs/wenetspeech/ASR/zipformer/train.py
+++ b/egs/wenetspeech/ASR/zipformer/train.py
@@ -758,7 +758,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -766,6 +766,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py
index 3ad16fd11..d19172b38 100755
--- a/egs/wenetspeech/KWS/zipformer/finetune.py
+++ b/egs/wenetspeech/KWS/zipformer/finetune.py
@@ -70,8 +70,7 @@ import copy
 import logging
 import warnings
 from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import k2
 import optim
@@ -80,7 +79,6 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import WenetSpeechAsrDataModule
 from lhotse.cut import Cut, CutSet
-from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from optim import Eden, ScaledAdam
 from torch import Tensor
@@ -103,14 +101,13 @@ from train import (
 
 from icefall import diagnostics
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
@@ -296,7 +293,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -304,6 +301,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]
 
         loss = 0.0
 
@@ -344,40 +342,6 @@ def compute_loss(
     return loss, info
 
 
-def compute_validation_loss(
-    params: AttributeDict,
-    model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
-    valid_dl: torch.utils.data.DataLoader,
-    world_size: int = 1,
-) -> MetricsTracker:
-    """Run the validation process."""
-    model.eval()
-
-    tot_loss = MetricsTracker()
-
-    for batch_idx, batch in enumerate(valid_dl):
-        loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            graph_compiler=graph_compiler,
-            batch=batch,
-            is_training=False,
-        )
-        assert loss.requires_grad is False
-        tot_loss = tot_loss + loss_info
-
-    if world_size > 1:
-        tot_loss.reduce(loss.device)
-
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
-    if loss_value < params.best_valid_loss:
-        params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = loss_value
-
-    return tot_loss
-
-
 def train_one_epoch(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py
index eddec7303..40960c2ae 100755
--- a/egs/wenetspeech/KWS/zipformer/train.py
+++ b/egs/wenetspeech/KWS/zipformer/train.py
@@ -815,7 +815,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -823,6 +823,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]
 
         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
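
Note (not part of the patch): every hunk above replaces direct tuple unpacking of the model output with a `losses` variable that is then sliced, e.g. `simple_loss, pruned_loss = losses[:2]`. Presumably this keeps each recipe's `compute_loss` working even when the model's forward returns additional loss terms beyond the ones that recipe consumes. A minimal Python sketch of the pattern, using a hypothetical stand-in for the model's forward:

from typing import Tuple

import torch


def model_forward_stub() -> Tuple[torch.Tensor, ...]:
    """Hypothetical stand-in for the model call above; the real forward also
    takes x, x_lens, y, prune_range, am_scale and lm_scale."""
    simple_loss = torch.tensor(3.0)
    pruned_loss = torch.tensor(1.5)
    ctc_loss = torch.tensor(0.7)
    extra_loss = torch.tensor(0.2)  # e.g. a term a newer model variant might add
    return simple_loss, pruned_loss, ctc_loss, extra_loss


# Old pattern: raises "too many values to unpack" if forward grows a new term.
# simple_loss, pruned_loss, ctc_loss = model_forward_stub()

# New pattern from the diff: keep only the losses this recipe actually uses.
losses = model_forward_stub()
simple_loss, pruned_loss, ctc_loss = losses[:3]  # recipes with a CTC head
simple_loss, pruned_loss = losses[:2]  # recipes that ignore the CTC loss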