From 23913f6afdea59caf703e3ac715852810cd246ad Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 31 Oct 2023 10:28:20 +0800 Subject: [PATCH] Minor refinements for some stale but recently merged PRs (#1354) * incorporate https://github.com/k2-fsa/icefall/pull/1269 * incorporate https://github.com/k2-fsa/icefall/pull/1301 * black formatted * incorporate https://github.com/k2-fsa/icefall/pull/1162 * black formatted --- egs/aishell/ASR/zipformer/train.py | 2 +- .../ASR/zipformer/asr_datamodule.py | 2 +- egs/gigaspeech/ASR/zipformer/train.py | 2 +- .../ASR/zipformer_prompt_asr/optim.py | 12 +++---- .../zipformer_prompt_asr/train_baseline.py | 2 +- .../train_bert_encoder.py | 2 +- .../ASR/tiny_transducer_ctc/asr_datamodule.py | 8 ++--- .../ASR/tiny_transducer_ctc/ctc_decode.py | 3 +- .../ASR/tiny_transducer_ctc/encoder.py | 1 - .../ASR/tiny_transducer_ctc/export.py | 31 ++++++----------- .../ASR/tiny_transducer_ctc/train.py | 4 +-- egs/librispeech/ASR/zipformer_ctc/export.py | 21 +++++------- egs/librispeech/ASR/zipformer_ctc/train.py | 2 +- icefall/diagnostics.py | 33 ++++++++++++------- 14 files changed, 57 insertions(+), 68 deletions(-) diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index 7e7b02829..d381649e4 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -1128,7 +1128,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index c4472ed23..6adfdbfbb 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -233,7 +233,7 @@ class GigaSpeechAsrDataModule: logging.info("About to get Musan cuts") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) ) else: logging.info("Disable MUSAN") diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index d8ff4fecc..d93cc221c 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -1164,7 +1164,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py index a767761eb..159e363c7 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/optim.py @@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer): yield tuples # <-- calling code will do the actual optimization here! 
- for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, ): - defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer): batch = True for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -429,7 +426,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -514,7 +511,7 @@ class ScaledAdam(BatchedOptimizer): from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars @@ -530,7 +527,6 @@ class ScaledAdam(BatchedOptimizer): for name, sumsq_orig, rms, grad in zip( batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad ): - proportion_orig = sumsq_orig / tot_sumsq all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) @@ -1106,7 +1102,7 @@ def _test_scaled_adam(hidden_dim: int): # if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 + # 512 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index 32302602c..c8b20d021 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -1194,7 +1194,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py index e253d1118..9822b99c1 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -1565,7 +1565,7 @@ def run(rank, world_size, args): if params.print_diagnostics: args.max_duration = 100 opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py index 8facb6dba..3acd22ae4 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) 
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
@@ -225,7 +225,7 @@ class LibriSpeechAsrDataModule:
             logging.info("About to get Musan cuts")
             cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
             transforms.append(
-                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
             )
         else:
             logging.info("Disable MUSAN")
@@ -307,8 +307,8 @@ class LibriSpeechAsrDataModule:
                 drop_last=self.args.drop_last,
             )
         else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
+            logging.info("Using SimpleCutSampler.")
+            train_sampler = SimpleCutSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
index 402aeac0c..cda03b56e 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
@@ -22,10 +22,11 @@
 import argparse
 import logging
 import math
+import pprint
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-import pprint
+
 import k2
 import sentencepiece as spm
 import torch
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py
index 4c7fca4fc..afdd00293 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py
@@ -253,7 +253,6 @@ class CausalSqueezeExcite1d(nn.Module):
         return y

     def forward(self, x: Tensor) -> Tensor:
-
         assert len(x.shape) == 3, "Input is not a 3D tensor!"
         y = self.exponential_moving_avg(x)
         y = y.permute(0, 2, 1)  # make channel last for squeeze op
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/export.py b/egs/librispeech/ASR/tiny_transducer_ctc/export.py
index 4117f7244..334dd011e 100755
--- a/egs/librispeech/ASR/tiny_transducer_ctc/export.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/export.py
@@ -76,17 +76,17 @@
 import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
+from train import add_model_arguments, get_params, get_transducer_model
+
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import UniqLexicon
-from icefall.utils import str2bool
-from train import add_model_arguments, get_params, get_transducer_model
+from icefall.utils import num_tokens, str2bool


 def get_parser():
@@ -143,13 +143,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
     )

     parser.add_argument(
@@ -189,17 +186,9 @@ def main():

     logging.info(f"device: {device}")

-    if "lang_bpe" in str(params.lang_dir):
-        sp = spm.SentencePieceProcessor()
-        sp.load(params.lang_dir + "/bpe.model")
-        # <blk> is defined in local/train_bpe_model.py
-        params.blank_id = sp.piece_to_id("<blk>")
-        params.vocab_size = sp.get_piece_size()
-    else:
-        assert "lang_phone" in str(params.lang_dir)
-        phone_lexicon = UniqLexicon(params.lang_dir)
-        params.blank_id = 0
-        params.vocab_size = max(phone_lexicon.tokens) + 1
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py
index 307ad72aa..8920764cd 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py
@@ -89,7 +89,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:


 def add_model_arguments(parser: argparse.ArgumentParser):
-
     parser.add_argument(
         "--encoder-dim",
         type=int,
@@ -405,7 +404,6 @@ def get_params() -> AttributeDict:


 def get_encoder_model(params: AttributeDict) -> nn.Module:
-
     encoder = Conv1dNet(
         output_dim=params.encoder_dim,
         input_dim=params.feature_dim,
@@ -1043,7 +1041,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

diff --git a/egs/librispeech/ASR/zipformer_ctc/export.py b/egs/librispeech/ASR/zipformer_ctc/export.py
index 0ff50f128..4c46aea2c 100755
--- a/egs/librispeech/ASR/zipformer_ctc/export.py
+++ b/egs/librispeech/ASR/zipformer_ctc/export.py
@@ -23,6 +23,7 @@ import argparse
 import logging
 from pathlib import Path

+import k2
 import torch
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_ctc_model, get_params
@@ -33,8 +34,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool


 def get_parser():
@@ -90,11 +90,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500",
-        help="""It contains language related input files such as "lexicon.txt"
-        """,
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
     )

     parser.add_argument(
@@ -113,17 +112,15 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)

     params = get_params()
     params.update(vars(args))

-    logging.info(params)
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank
-    params.vocab_size = num_classes
+    logging.info(params)

     device = torch.device("cpu")
     if torch.cuda.is_available():
diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py
index f40344357..60990456d 100755
--- a/egs/librispeech/ASR/zipformer_ctc/train.py
+++ b/egs/librispeech/ASR/zipformer_ctc/train.py
@@ -947,7 +947,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)

diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py
index ebf61784e..65b6f67b0 100644
--- a/icefall/diagnostics.py
+++ b/icefall/diagnostics.py
@@ -244,16 +244,14 @@ class TensorDiagnostic(object):

             if stats_type == "eigs":
                 try:
-                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
+                    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
                         eigs, _ = torch.linalg.eigh(stats)
                     else:
                         eigs, _ = torch.symeig(stats)
                     stats = eigs.abs().sqrt()
                 except:  # noqa
-                    print(
-                        "Error getting eigenvalues, trying another method."
-                    )
-                    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eig'):
+                    print("Error getting eigenvalues, trying another method.")
+                    if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
                         eigs, _ = torch.linalg.eig(stats)
                         eigs = eigs.abs()
                     else:
@@ -579,10 +577,15 @@ def attach_diagnostics(
                 )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                            class_name=get_class_name(_module))
-
+                    if isinstance(o, Tensor) and o.dtype in (
+                        torch.float32,
+                        torch.float16,
+                        torch.float64,
+                    ):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module)
+                        )
+
         def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
@@ -596,9 +599,15 @@ def attach_diagnostics(
                 )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if isinstance(o, Tensor) and o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                            class_name=get_class_name(_module))
+                    if isinstance(o, Tensor) and o.dtype in (
+                        torch.float32,
+                        torch.float16,
+                        torch.float64,
+                    ):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module)
+                        )
+
         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)
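
Editor's note on the export-script hunks above: both export.py files now derive
blank_id and vocab_size from tokens.txt rather than a lang dir. The standalone
sketch below shows the equivalent lookup, assuming the usual icefall tokens.txt
layout of one "<symbol> <id>" pair per line; read_token_table and count_tokens
are illustrative names only, not icefall's API (the scripts themselves use
k2.SymbolTable.from_file() and icefall.utils.num_tokens(), as the diff shows):

    import re
    from pathlib import Path

    def read_token_table(tokens_file: str) -> dict:
        """Parse tokens.txt, where each non-empty line is "<symbol> <id>"."""
        table = {}
        for line in Path(tokens_file).read_text(encoding="utf-8").splitlines():
            if line.strip():
                sym, idx = line.rsplit(maxsplit=1)
                table[sym] = int(idx)
        return table

    def count_tokens(table: dict) -> int:
        """Approximates icefall.utils.num_tokens: count ids, skipping
        disambiguation symbols such as #0/#1 and the id 0 (<blk>)."""
        ids = {i for s, i in table.items() if not re.match(r"^#\d+$", s)}
        return len(ids) - (1 if 0 in ids else 0)

    token_table = read_token_table("data/lang_bpe_500/tokens.txt")
    blank_id = token_table["<blk>"]             # conventionally id 0
    vocab_size = count_tokens(token_table) + 1  # +1 for <blk>, as in the diff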
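
The diagnostics.py hunk likewise keeps a guarded eigendecomposition so the same
code runs on both old and new PyTorch: torch.linalg.eigh exists from PyTorch 1.8
on, while torch.symeig was deprecated and later removed. A minimal sketch of
that fallback pattern (illustrative, not icefall's code; sym_eigenvalues is a
hypothetical helper):

    import torch

    def sym_eigenvalues(stats: torch.Tensor) -> torch.Tensor:
        """Eigenvalues of a symmetric matrix: prefer torch.linalg.eigh
        (PyTorch >= 1.8), fall back to the legacy torch.symeig."""
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
            eigs, _ = torch.linalg.eigh(stats)  # same unpacking as the diff
        else:
            eigs, _ = torch.symeig(stats)
        return eigs

    x = torch.randn(8, 8)
    print(sym_eigenvalues(x @ x.t()))  # symmetric PSD input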