From 86bd16d496ecdd7d0d487d1949a137df624e5147 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Wed, 2 Apr 2025 22:10:06 +0800 Subject: [PATCH 1/4] [KWS]Remove graph compiler (#1905) --- egs/wenetspeech/KWS/run.sh | 6 +-- egs/wenetspeech/KWS/zipformer/decode.py | 1 - egs/wenetspeech/KWS/zipformer/finetune.py | 39 +++++--------- egs/wenetspeech/KWS/zipformer/train.py | 65 ++++++++++------------- 4 files changed, 43 insertions(+), 68 deletions(-) diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index 8472b8531..0af7c1595 100755 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -108,7 +108,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 2: Finetune the model" + log "Stage 3: Finetune the model" # The following configuration of lr schedule should work well # You may also tune the following parameters to adjust learning rate schedule @@ -143,7 +143,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 1: Decode the finetuned model." + log "Stage 4: Decode the finetuned model." export CUDA_VISIBLE_DEVICES="0" for t in small large; do python ./zipformer/decode.py \ @@ -170,7 +170,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 2: Export the finetuned model." + log "Stage 5: Export the finetuned model." python ./zipformer/export.py \ --epoch 10 \ diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py index 340a41231..a628c7e58 100755 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -35,7 +35,6 @@ from lhotse.cut import Cut from train import add_model_arguments, get_model, get_params from icefall import ContextGraph -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index d19172b38..cd437da4c 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -90,6 +90,7 @@ from train import ( add_training_arguments, compute_validation_loss, display_and_save_batch, + encode_text, get_adjusted_batch_count, get_model, get_params, @@ -100,7 +101,6 @@ from train import ( ) from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -110,11 +110,11 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, + num_tokens, setup_logger, str2bool, text_to_pinyin, @@ -254,7 +254,6 @@ def load_model_params( def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: @@ -289,7 +288,7 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts, sep="/") + y = [c.supervisions[0].tokens 
for c in supervisions["cut"]] y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): @@ -347,7 +346,6 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -418,7 +416,6 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -436,7 +433,7 @@ def train_one_epoch( optimizer.zero_grad() except: # noqa save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params) raise if params.print_diagnostics and batch_idx == 5: @@ -523,7 +520,6 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, - graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, ) @@ -576,14 +572,10 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) + token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt") - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 if not params.use_transducer: params.ctc_loss_scale = 1.0 @@ -666,17 +658,10 @@ def run(rank, world_size, args): else: train_cuts = wenetspeech.nihaowenwen_train_cuts() - def encode_text(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = "/".join( - text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) - ) - c.supervisions[0].text = text - return c + _encode_text = partial(encode_text, token_table=token_table, params=params) train_cuts = train_cuts.filter(remove_short_utt) - train_cuts = train_cuts.map(encode_text) + train_cuts = train_cuts.map(_encode_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -691,7 +676,7 @@ def run(rank, world_size, args): valid_cuts = wenetspeech.nihaowenwen_dev_cuts() valid_cuts = valid_cuts.filter(remove_short_utt) - valid_cuts = valid_cuts.map(encode_text) + valid_cuts = valid_cuts.map(_encode_text) valid_dl = wenetspeech.valid_dataloaders(valid_cuts) if not params.print_diagnostics and params.scan_for_oom_batches: @@ -699,7 +684,6 @@ def run(rank, world_size, args): model=model, train_dl=train_dl, optimizer=optimizer, - graph_compiler=graph_compiler, params=params, ) @@ -724,7 +708,6 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - graph_compiler=graph_compiler, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -760,6 +743,8 @@ def main(): WenetSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.return_cuts = True world_size = args.world_size assert world_size >= 1 diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index 40960c2ae..5d9d8de36 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -53,6 +53,7 @@ import argparse import copy import logging import warnings +from functools import partial 
from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union @@ -79,7 +80,6 @@ from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -90,11 +90,11 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.err import raise_grad_scale_is_too_small_error from icefall.hooks import register_inf_check_hooks -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, MetricsTracker, get_parameter_groups_with_lrs, + num_tokens, setup_logger, str2bool, text_to_pinyin, @@ -776,7 +776,6 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: @@ -811,7 +810,7 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids(texts, sep="/") + y = [c.supervisions[0].tokens for c in supervisions["cut"]] y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): @@ -859,7 +858,6 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], - graph_compiler: CharCtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> MetricsTracker: @@ -872,7 +870,6 @@ def compute_validation_loss( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=False, ) @@ -895,7 +892,6 @@ def train_one_epoch( model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, @@ -971,7 +967,6 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -988,7 +983,7 @@ def train_one_epoch( optimizer.zero_grad() except: # noqa save_bad_model() - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params) raise if params.print_diagnostics and batch_idx == 5: @@ -1077,7 +1072,6 @@ def train_one_epoch( valid_info = compute_validation_loss( params=params, model=model, - graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, ) @@ -1098,6 +1092,20 @@ def train_one_epoch( params.best_train_loss = params.train_loss +def encode_text(c: Cut, token_table: k2.SymbolTable, params: AttributeDict): + text = c.supervisions[0].text + tokens = text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) + ids = [] + for t in tokens: + if t in token_table: + ids.append(token_table[t]) + else: + logging.warning(f"Text : {text} has OOV token : {t} , encode to ") + ids.append(token_table[""]) + c.supervisions[0].tokens = ids + return c + + def run(rank, world_size, args): """ Args: @@ -1130,14 +1138,10 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) + 
token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt") - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 if not params.use_transducer: params.ctc_loss_scale = 1.0 @@ -1216,17 +1220,10 @@ def run(rank, world_size, args): return True - def encode_text(c: Cut): - # Text normalize for each sample - text = c.supervisions[0].text - text = "/".join( - text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) - ) - c.supervisions[0].text = text - return c + _encode_text = partial(encode_text, token_table=token_table, params=params) train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_cuts = train_cuts.map(encode_text) + train_cuts = train_cuts.map(_encode_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1240,7 +1237,7 @@ def run(rank, world_size, args): ) valid_cuts = wenetspeech.valid_cuts() - valid_cuts = valid_cuts.map(encode_text) + valid_cuts = valid_cuts.map(_encode_text) valid_dl = wenetspeech.valid_dataloaders(valid_cuts) if not params.print_diagnostics and params.scan_for_oom_batches: @@ -1248,7 +1245,6 @@ def run(rank, world_size, args): model=model, train_dl=train_dl, optimizer=optimizer, - graph_compiler=graph_compiler, params=params, ) @@ -1273,7 +1269,6 @@ def run(rank, world_size, args): model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, - graph_compiler=graph_compiler, train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, @@ -1307,7 +1302,6 @@ def run(rank, world_size, args): def display_and_save_batch( batch: dict, params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, ) -> None: """Display the batch statistics and save the batch into disk. @@ -1317,8 +1311,6 @@ def display_and_save_batch( for the content in it. params: Parameters for training. See :func:`get_params`. - graph_compiler: - The compiler to encode texts to ids. """ from lhotse.utils import uuid4 @@ -1332,8 +1324,8 @@ def display_and_save_batch( logging.info(f"features shape: {features.shape}") texts = supervisions["text"] - y = graph_compiler.texts_to_ids(texts) - num_tokens = sum(len(i) for i in y) + tokens = [c.supervisions[0].tokens for c in supervisions["cut"]] + num_tokens = sum(len(i) for i in tokens) logging.info(f"num tokens: {num_tokens}") @@ -1341,7 +1333,6 @@ def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, params: AttributeDict, ): from lhotse.dataset import find_pessimistic_batches @@ -1357,7 +1348,6 @@ def scan_pessimistic_batches_for_oom( loss, _ = compute_loss( params=params, model=model, - graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -1372,7 +1362,7 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." 
) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch(batch, params=params) raise logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" @@ -1385,6 +1375,7 @@ def main(): args = parser.parse_args() args.lang_dir = Path(args.lang_dir) args.exp_dir = Path(args.exp_dir) + args.return_cuts = True world_size = args.world_size assert world_size >= 1 From 171cf8c9fe41666c7d70ba58c0481ec6675d8941 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Apr 2025 11:52:37 +0800 Subject: [PATCH 2/4] Avoid redundant computation in PiecewiseLinear. (#1915) --- egs/librispeech/ASR/zipformer/scaling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d345c2931..6d6281903 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -160,8 +160,10 @@ class PiecewiseLinear(object): extra_x_vals.append(extra_x_val) if len(extra_x_vals) > 0: x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [self(x) for x in x_vals] - y_vals2 = [p(x) for x in x_vals] + + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( PiecewiseLinear(*zip(x_vals, y_vals1)), PiecewiseLinear(*zip(x_vals, y_vals2)), From 300a821f58abdb9975a3743caf1a78953e97711b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 10 Apr 2025 10:30:37 +0800 Subject: [PATCH 3/4] Fix aishell training (#1916) --- egs/aishell/ASR/zipformer/train.py | 5 ++--- egs/aishell/ASR/zipformer/train_bbpe.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index cd253c597..dddfe52fa 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -1343,8 +1343,7 @@ def main(): run(rank=0, world_size=1, args=args) -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py index 46a5506db..dbc262c5c 100755 --- a/egs/aishell/ASR/zipformer/train_bbpe.py +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -935,8 +935,7 @@ def main(): run(rank=0, world_size=1, args=args) -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() From 64c53640857d0b9c3fd63070c2f741d374051ce9 Mon Sep 17 00:00:00 2001 From: math345 Date: Thu, 10 Apr 2025 11:37:28 +0800 Subject: [PATCH 4/4] Fix bug: When resuming training from a checkpoint, model_avg was not assigned, resulting in a None error. (#1914) --- egs/wenetspeech/KWS/zipformer/finetune.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index cd437da4c..249209352 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -593,6 +593,9 @@ def run(rank, world_size, args): if params.continue_finetune: assert params.start_epoch > 0, params.start_epoch + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) checkpoints = load_checkpoint_if_available( params=params, model=model, model_avg=model_avg )
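
Patch 1/4 replaces CharCtcTrainingGraphCompiler with a direct symbol-table lookup. The new encode_text() helper converts each cut's text to pinyin with text_to_pinyin, maps the tokens to integer ids from tokens.txt, and stores them on the cut as supervisions[0].tokens; compute_loss then reads those ids back from supervisions["cut"] and wraps them in a k2.RaggedTensor, which is also why both main() functions now set args.return_cuts = True. A minimal sketch of the lookup follows, with a plain dict standing in for k2.SymbolTable, made-up token ids, an assumed <unk> fallback, and an illustrative helper name (encode_tokens) that is not from the recipe:

# Sketch only: plain dict in place of k2.SymbolTable; ids and <unk> fallback are assumptions.
token_table = {"<blk>": 0, "<unk>": 1, "ni3": 2, "hao3": 3}

def encode_tokens(pinyin_tokens, table, oov="<unk>"):
    # Map each pinyin token to its integer id; out-of-vocabulary tokens fall
    # back to the reserved OOV symbol, as the patch's encode_text() does after
    # logging a warning.
    return [table.get(t, table[oov]) for t in pinyin_tokens]

ids = encode_tokens(["ni3", "hao3", "ma5"], token_table)
print(ids)  # [2, 3, 1]  ("ma5" is OOV here and maps to the fallback symbol)

The resulting id lists are exactly what compute_loss now expects per cut, so no graph compiler object needs to be threaded through train_one_epoch, compute_validation_loss, or display_and_save_batch.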
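Patch 3/4 moves torch.set_num_threads(1) and torch.set_num_interop_threads(1) inside the if __name__ == "__main__": guard of the aishell training scripts, so the calls run only when the file is executed directly. At module level they would also run whenever the script is imported by another recipe file, and torch.set_num_interop_threads() can only be called once, before any inter-op parallel work has started, so an import-time call is fragile. A sketch of the guarded shape, with a placeholder main():

import torch

def main():
    pass  # placeholder training entry point for this sketch

if __name__ == "__main__":
    # Thread settings run only when the file is executed as a script,
    # not when another recipe script imports it.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()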
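Patch 4/4 addresses resuming fine-tuning: with params.continue_finetune set, load_checkpoint_if_available() was handed model_avg while it was still None, so restoring the averaged-model state failed. The fix deep-copies the model into model_avg on rank 0 before the checkpoint is loaded. A stripped-down sketch of the required ordering, where nn.Linear(80, 500) stands in for the real model and the load_checkpoint_if_available call is only indicated:

import copy
import torch
import torch.nn as nn

rank = 0
model = nn.Linear(80, 500)  # stand-in for get_model(params) in this sketch

model_avg = None
if rank == 0:
    # model_avg is tracked on rank 0 only; it has to exist before the
    # checkpoint loader tries to copy the saved averaged parameters into it.
    model_avg = copy.deepcopy(model).to(torch.float64)

# checkpoints = load_checkpoint_if_available(
#     params=params, model=model, model_avg=model_avg
# )
assert model_avg is not None  # previously still None on the resume path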