From d65187ec5245457a43e352f4c0c9930ab2d98225 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 11 Jul 2024 14:45:35 +0800 Subject: [PATCH 1/4] Small fix (#1686) --- egs/librispeech/ASR/zipformer/scaling.py | 5 +++-- egs/librispeech/ASR/zipformer/train.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e7c3f4ab1..3c7e0fa4e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -636,8 +636,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): ) def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: - """ - Forward function. Args: + """Forward function. + + Args: x: a Tensor of shape (batch_size, channels, seq_len) chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. """ diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 3797de484..9b6f4a93a 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -406,7 +406,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -429,7 +429,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -848,7 +848,7 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; + warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" device = model.device if isinstance(model, DDP) else next(model.parameters()).device From 19048e155b5a07b4dd4d1795815a1a0cd4584e25 Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:12:30 +0900 Subject: [PATCH 2/4] Cast grad_scale in whiten to float (#1663) * cast grad_scale in whiten to float * fix cast in zipformer_lora --- egs/librispeech/ASR/zipformer/scaling.py | 4 ++-- egs/librispeech/ASR/zipformer_lora/scaling.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 3c7e0fa4e..164cc7bfd 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1033,7 +1033,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * ( + scale = float(w.grad_scale) * ( x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) @@ -1075,7 +1075,7 @@ class Whiten(nn.Module): super(Whiten, self).__init__() assert num_groups >= 1 assert float(whitening_limit) >= 1 - assert grad_scale >= 0 + assert float(grad_scale) >= 0 self.num_groups = num_groups self.whitening_limit = whitening_limit self.grad_scale = grad_scale diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index 3149db9f3..8d7aa8027 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -1137,7 +1137,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * ( + scale = float(w.grad_scale) * ( x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) @@ -1179,7 +1179,7 @@ class Whiten(nn.Module): super(Whiten, self).__init__() assert num_groups >= 1 assert float(whitening_limit) >= 1 - assert grad_scale >= 0 + assert float(grad_scale) >= 0 self.num_groups = num_groups self.whitening_limit = whitening_limit self.grad_scale = grad_scale From f6febd658eb5f1b52771c0b88a5d1205e0d40370 Mon Sep 17 00:00:00 2001 From: Ziwei Li <99643269+NLPvv@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:42:00 +0800 Subject: [PATCH 3/4] "-" replace "_" fix writing error (#1687) --- egs/gigaspeech/KWS/run.sh | 4 ++-- egs/wenetspeech/KWS/run.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh index bd562ce1c..303abd718 100755 --- a/egs/gigaspeech/KWS/run.sh +++ b/egs/gigaspeech/KWS/run.sh @@ -90,7 +90,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 - python ./zipformer/export_onnx_streaming.py \ + python ./zipformer/export-onnx-streaming.py \ --exp-dir zipformer/exp \ --tokens data/lang_bpe_500/tokens.txt \ --epoch 12 \ @@ -184,7 +184,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 - python ./zipformer/export_onnx_streaming.py \ + python ./zipformer/export-onnx-streaming.py \ --exp-dir zipformer/exp_finetune \ --tokens data/lang_bpe_500/tokens.txt \ --epoch 10 \ diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index 232ee039a..8472b8531 100755 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -91,7 +91,7 @@ if [ $stage -le 2 ] 
&& [ $stop_stage -ge 2 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 - python ./zipformer/export_onnx_streaming.py \ + python ./zipformer/export-onnx-streaming.py \ --exp-dir zipformer/exp \ --tokens data/lang_partial_tone/tokens.txt \ --epoch 18 \ @@ -187,7 +187,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 - python ./zipformer/export_onnx_streaming.py \ + python ./zipformer/export-onnx-streaming.py \ --exp-dir zipformer/exp_finetune \ --tokens data/lang_partial_tone/tokens.txt \ --epoch 10 \ From 334beed2af5212b1b2b8ca112893b120e83d0516 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Fri, 12 Jul 2024 16:50:58 +0800 Subject: [PATCH 4/4] fix usages of returned losses after adding attention-decoder in zipformer (#1689) --- egs/aishell/ASR/zipformer/train.py | 3 +- egs/aishell/ASR/zipformer/train_bbpe.py | 3 +- egs/commonvoice/ASR/zipformer/train.py | 3 +- egs/commonvoice/ASR/zipformer/train_char.py | 3 +- egs/gigaspeech/ASR/zipformer/train.py | 3 +- egs/gigaspeech/KWS/zipformer/train.py | 3 +- egs/ksponspeech/ASR/zipformer/train.py | 3 +- egs/libriheavy/ASR/zipformer/train.py | 4 +- egs/librispeech/ASR/zipformer/finetune.py | 3 +- .../ASR/zipformer_adapter/train.py | 3 +- .../ASR/zipformer_lora/finetune.py | 3 +- egs/librispeech/ASR/zipformer_lora/train.py | 3 +- egs/mdcc/ASR/zipformer/train.py | 3 +- egs/multi_zh-hans/ASR/zipformer/train.py | 3 +- egs/multi_zh_en/ASR/zipformer/train.py | 3 +- egs/reazonspeech/ASR/zipformer/train.py | 4 +- egs/spgispeech/ASR/zipformer/train.py | 4 +- egs/wenetspeech/ASR/zipformer/train.py | 3 +- egs/wenetspeech/KWS/zipformer/finetune.py | 44 ++----------------- egs/wenetspeech/KWS/zipformer/train.py | 3 +- 20 files changed, 42 insertions(+), 62 deletions(-) diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index a25979226..cd253c597 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -758,7 +758,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -766,6 +766,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py index 0713c5787..46a5506db 100755 --- a/egs/aishell/ASR/zipformer/train_bbpe.py +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -343,7 +343,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -351,6 +351,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py index 5cda9bfd4..271014db0 100755 --- a/egs/commonvoice/ASR/zipformer/train.py +++ b/egs/commonvoice/ASR/zipformer/train.py @@ -814,7 +814,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, 
x_lens=feature_lens, y=y, @@ -822,6 +822,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py index a780bbbbc..0aa7856cc 100755 --- a/egs/commonvoice/ASR/zipformer/train_char.py +++ b/egs/commonvoice/ASR/zipformer/train_char.py @@ -449,7 +449,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -457,6 +457,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index f0ad98147..4c122effe 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -803,7 +803,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -811,6 +811,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index a4d670169..39d8fc6cd 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -806,7 +806,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -814,6 +814,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py index b612b6835..485ea69c9 100755 --- a/egs/ksponspeech/ASR/zipformer/train.py +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -787,7 +787,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -795,6 +795,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py index 8d4d9d067..357e8a827 100644 --- a/egs/libriheavy/ASR/zipformer/train.py +++ b/egs/libriheavy/ASR/zipformer/train.py @@ -55,7 +55,6 @@ It supports training with: import argparse import copy import logging -import random import warnings from pathlib import Path from shutil import copyfile @@ -804,7 +803,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -812,6 +811,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 2f7ec0c17..2ff631914 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -893,7 +893,7 @@ def compute_loss( y = k2.RaggedTensor(y) with 
torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -901,6 +901,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 6c55896a8..3511590da 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -890,7 +890,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -898,6 +898,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py index 0464cf65c..3f36f229f 100755 --- a/egs/librispeech/ASR/zipformer_lora/finetune.py +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -903,7 +903,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -911,6 +911,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py index 3ccf7d2f1..9ab214e86 100755 --- a/egs/librispeech/ASR/zipformer_lora/train.py +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -792,7 +792,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -800,6 +800,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py index 2fae66844..730db7718 100755 --- a/egs/mdcc/ASR/zipformer/train.py +++ b/egs/mdcc/ASR/zipformer/train.py @@ -754,7 +754,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -762,6 +762,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 1fc4c35c1..3dbfc48eb 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -832,7 +832,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -840,6 +840,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py index 5dba584f7..04bb41214 100755 --- a/egs/multi_zh_en/ASR/zipformer/train.py +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -814,7 +814,7 @@ def compute_loss( 
y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -822,6 +822,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/reazonspeech/ASR/zipformer/train.py b/egs/reazonspeech/ASR/zipformer/train.py index 8c6f4bb9a..30bd3efba 100755 --- a/egs/reazonspeech/ASR/zipformer/train.py +++ b/egs/reazonspeech/ASR/zipformer/train.py @@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union import k2 import optim -import sentencepiece as spm import torch import torch.multiprocessing as mp import torch.nn as nn @@ -791,7 +790,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -799,6 +798,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/spgispeech/ASR/zipformer/train.py b/egs/spgispeech/ASR/zipformer/train.py index ed66ca29b..dfc21c968 100755 --- a/egs/spgispeech/ASR/zipformer/train.py +++ b/egs/spgispeech/ASR/zipformer/train.py @@ -67,7 +67,6 @@ import torch.nn as nn from asr_datamodule import SPGISpeechAsrDataModule from decoder import Decoder from joiner import Joiner -from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -792,7 +791,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -800,6 +799,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index 3d3762916..25b16f632 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -758,7 +758,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -766,6 +766,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index 3ad16fd11..d19172b38 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -70,8 +70,7 @@ import copy import logging import warnings from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import k2 import optim @@ -80,7 +79,6 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import WenetSpeechAsrDataModule from lhotse.cut import Cut, CutSet -from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor @@ -103,14 +101,13 @@ from train import ( from icefall import diagnostics from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from 
icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, update_averaged_model, ) from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks from icefall.lexicon import Lexicon @@ -296,7 +293,7 @@ def compute_loss( y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -304,6 +301,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss, ctc_loss = losses[:3] loss = 0.0 @@ -344,40 +342,6 @@ def compute_loss( return loss, info -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index eddec7303..40960c2ae 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -815,7 +815,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + losses = model( x=feature, x_lens=feature_lens, y=y, @@ -823,6 +823,7 @@ def compute_loss( am_scale=params.am_scale, lm_scale=params.lm_scale, ) + simple_loss, pruned_loss = losses[:2] s = params.simple_loss_scale # take down the scale on the simple loss from 1.0 at the start
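
The change that patch 4/4 repeats across all twenty recipes follows a single pattern: after the attention decoder was added to the zipformer AsrModel, its forward() returns one additional loss term, so fixed-arity unpacking such as `simple_loss, pruned_loss, ctc_loss = model(...)` would raise "too many values to unpack". Each compute_loss() therefore now binds the whole return value and slices off the prefix it actually uses. The sketch below illustrates the idea; it is a minimal stand-in under stated assumptions, not the actual icefall AsrModel, and the loss values are hypothetical placeholders.

import torch

def model_forward() -> tuple:
    """Stand-in for AsrModel.forward() after the attention decoder was
    added: one extra loss is appended to the returned tuple."""
    simple_loss = torch.tensor(1.0)
    pruned_loss = torch.tensor(2.0)
    ctc_loss = torch.tensor(0.5)
    attention_decoder_loss = torch.tensor(0.7)
    return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss

# Old call sites unpacked a fixed number of values and broke once the
# tuple grew:
#     simple_loss, pruned_loss, ctc_loss = model_forward()  # ValueError
#
# The patched call sites bind the tuple and slice the prefix they need,
# so the same code keeps working if forward() grows again:
losses = model_forward()
simple_loss, pruned_loss = losses[:2]            # transducer-only recipes
simple_loss, pruned_loss, ctc_loss = losses[:3]  # recipes that also use CTC

Slicing `losses[:2]` or `losses[:3]` rather than unpacking with a trailing `*_` keeps the intent explicit at each call site: every recipe names exactly the losses it consumes and silently ignores anything the model appends later.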