From f686635b546baa00654f9e3caed739adf04c399e Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 30 Mar 2022 14:52:55 +0800 Subject: [PATCH 1/7] Update diagnostics (#260) * update diagnostics.py --- icefall/diagnostics.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index fa9b98fa0..08d1628ec 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -135,8 +135,13 @@ def get_diagnostics_for_dim( return "" count = sum(counts) stats = stats / count - stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() + try: + eigs, _ = torch.symeig(stats) + stats = eigs.abs().sqrt() + except: # noqa + print("Error getting eigenvalues, trying another method.") + eigs, _ = torch.eigs(stats) + stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: stats = torch.stack(stats).sum(dim=0) From 981b0640079918a43826b82acdadde68e2517bc9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 30 Mar 2022 18:50:54 +0800 Subject: [PATCH 2/7] Update doc to clarify the installation order of dependencies. (#279) --- docs/source/installation/index.rst | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index a8c3b6865..5d364dbc0 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -27,9 +27,21 @@ Installation ``icefall`` depends on `k2 `_ and `lhotse `_. -We recommend you to install ``k2`` first, as ``k2`` is bound to -a specific version of PyTorch after compilation. Install ``k2`` also -installs its dependency PyTorch, which can be reused by ``lhotse``. +We recommend you to use the following steps to install the dependencies. + +- (0) Install PyTorch and torchaudio +- (1) Install k2 +- (2) Install lhotse + +.. caution:: + + Installation order matters. + +(0) Install PyTorch and torchaudio +---------------------------------- + +Please refer ``_ to install PyTorch +and torchaudio. (1) Install k2 @@ -54,14 +66,15 @@ to install ``k2``. Please refer to ``_ to install ``lhotse``. -.. HINT:: - Install ``lhotse`` also installs its dependency `torchaudio `_. +.. hint:: -.. CAUTION:: + We strongly recommend you to use:: + + pip install git+https://github.com/lhotse-speech/lhotse + + to install the latest version of lhotse. - If you have installed ``torchaudio``, please consider uninstalling it before - installing ``lhotse``. Otherwise, it may update your already installed PyTorch. (3) Download icefall -------------------- From 2045125fd96a8c0c925f6824d90512e43ac01fb5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 10:43:02 +0800 Subject: [PATCH 3/7] Fix CI. (#280) * Fix CI. --- .github/workflows/style_check.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 2a743705a..6b3d856df 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -45,7 +45,9 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 + python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 + # See https://github.com/psf/black/issues/2964 + # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 - name: Run flake8 shell: bash From fc40bfea8222400ffdcb437d0d4708053a619cb2 Mon Sep 17 00:00:00 2001 From: "LIyong.Guo" <839019390@qq.com> Date: Thu, 31 Mar 2022 10:43:46 +0800 Subject: [PATCH 4/7] fix typo of torch.eig (#281) Co-authored-by: glynpu --- icefall/diagnostics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 08d1628ec..ce4ac1464 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -140,7 +140,7 @@ def get_diagnostics_for_dim( stats = eigs.abs().sqrt() except: # noqa print("Error getting eigenvalues, trying another method.") - eigs, _ = torch.eigs(stats) + eigs = torch.linalg.eigvals(stats) stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: From 9a11808ed36b57cb17cfd328f1a8537f86f468a5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 16:48:46 +0800 Subject: [PATCH 5/7] Set the seed for dataloader. (#282) Also, suppress torch warnings about division by truncation. --- .../ASR/pruned_transducer_stateless/train.py | 7 ++++++- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 10 ++++++++++ egs/librispeech/ASR/transducer/train.py | 7 ++++++- egs/librispeech/ASR/transducer_lstm/train.py | 7 ++++++- egs/librispeech/ASR/transducer_stateless/conformer.py | 7 +++++-- egs/librispeech/ASR/transducer_stateless/train.py | 7 ++++++- .../asr_datamodule.py | 11 +++++++++++ .../ASR/transducer_stateless_multi_datasets/train.py | 7 ++++++- 8 files changed, 56 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 1f52370fd..17f82e601 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple @@ -496,7 +497,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a460c8eb8..8790b21e7 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -23,6 +23,7 @@ from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional +import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, @@ -34,6 +35,7 @@ from lhotse.dataset import ( SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -301,12 +303,20 @@ class LibriSpeechAsrDataModule: logging.info("Loading sampler state dict") train_sampler.load_state_dict(sampler_state_dict) + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + + def worker_init_fn(worker_id: int): + fix_random_seed(seed + worker_id) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index a6ce79520..cbd9259e0 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -393,7 +394,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 9f06ed512..eef4d3430 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -397,7 +398,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index fc838f75b..488c82386 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -109,8 +109,11 @@ class Conformer(Transformer): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 2cc6480d5..d6827c17c 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -419,7 +420,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py index 669ad1d1b..2ce8d8752 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py @@ -22,6 +22,7 @@ import logging from pathlib import Path from typing import Optional +import torch from lhotse import CutSet, Fbank, FbankConfig from lhotse.dataset import ( BucketingSampler, @@ -34,6 +35,7 @@ from lhotse.dataset.input_strategies import ( OnTheFlyFeatures, PrecomputedFeatures, ) +from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -253,12 +255,21 @@ class AsrDataModule: ) logging.info("About to create train dataloader") + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + + def worker_init_fn(worker_id: int): + fix_random_seed(seed + worker_id) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py index 105f82417..5572d3f4c 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py @@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging import random +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -466,7 +467,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() From e7493ede9069c725e083235b4bfa50bc81e5cf45 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 20:32:00 +0800 Subject: [PATCH 6/7] Don't use a lambda for dataloader's worker_init_fn. (#284) * Don't use a lambda for dataloader's worker_init_fn. --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 12 +++++++++--- .../asr_datamodule.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 8790b21e7..8dd1459ca 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -41,6 +41,14 @@ from torch.utils.data import DataLoader from icefall.utils import str2bool +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + class LibriSpeechAsrDataModule: """ DataModule for k2 ASR experiments. @@ -306,9 +314,7 @@ class LibriSpeechAsrDataModule: # 'seed' is derived from the current random state, which will have # previously been set in the main process. seed = torch.randint(0, 100000, ()).item() - - def worker_init_fn(worker_id: int): - fix_random_seed(seed + worker_id) + worker_init_fn = _SeedWorkers(seed) train_dl = DataLoader( train, diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py index 2ce8d8752..c6cf739fb 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py @@ -41,6 +41,14 @@ from torch.utils.data import DataLoader from icefall.utils import str2bool +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + class AsrDataModule: def __init__(self, args: argparse.Namespace): self.args = args @@ -259,9 +267,7 @@ class AsrDataModule: # 'seed' is derived from the current random state, which will have # previously been set in the main process. seed = torch.randint(0, 100000, ()).item() - - def worker_init_fn(worker_id: int): - fix_random_seed(seed + worker_id) + worker_init_fn = _SeedWorkers(seed) train_dl = DataLoader( train, From 0b6a2213c389b2663d1adccb690a3df1f1b1f5a9 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 2 Apr 2022 15:01:45 +0800 Subject: [PATCH 7/7] Modify icefall/__init__.py. (#287) * Modify icefall/__init__.py to import common functions defined in icefall/utils.py. * Modify icefall/__init__.py and .flake8. --- .flake8 | 3 ++- icefall/__init__.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 229cf1d6c..dd9239b2d 100644 --- a/.flake8 +++ b/.flake8 @@ -13,4 +13,5 @@ per-file-ignores = exclude = .git, **/data/**, - icefall/shared/make_kn_lm.py + icefall/shared/make_kn_lm.py, + icefall/__init__.py diff --git a/icefall/__init__.py b/icefall/__init__.py index e69de29bb..983539d6f 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -0,0 +1,24 @@ +from .utils import ( + AttributeDict, + MetricsTracker, + add_eos, + add_sos, + concat, + encode_supervisions, + get_alignments, + get_executor, + get_texts, + l1_norm, + l2_norm, + linf_norm, + load_alignments, + make_pad_mask, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + save_alignments, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +)