From f686635b546baa00654f9e3caed739adf04c399e Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Wed, 30 Mar 2022 14:52:55 +0800 Subject: [PATCH 01/12] Update diagnostics (#260) * update diagnostics.py --- icefall/diagnostics.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index fa9b98fa0..08d1628ec 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -135,8 +135,13 @@ def get_diagnostics_for_dim( return "" count = sum(counts) stats = stats / count - stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() + try: + eigs, _ = torch.symeig(stats) + stats = eigs.abs().sqrt() + except: # noqa + print("Error getting eigenvalues, trying another method.") + eigs, _ = torch.eigs(stats) + stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: stats = torch.stack(stats).sum(dim=0) From 981b0640079918a43826b82acdadde68e2517bc9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 30 Mar 2022 18:50:54 +0800 Subject: [PATCH 02/12] Update doc to clarify the installation order of dependencies. (#279) --- docs/source/installation/index.rst | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index a8c3b6865..5d364dbc0 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -27,9 +27,21 @@ Installation ``icefall`` depends on `k2 `_ and `lhotse `_. -We recommend you to install ``k2`` first, as ``k2`` is bound to -a specific version of PyTorch after compilation. Install ``k2`` also -installs its dependency PyTorch, which can be reused by ``lhotse``. +We recommend you to use the following steps to install the dependencies. + +- (0) Install PyTorch and torchaudio +- (1) Install k2 +- (2) Install lhotse + +.. caution:: + + Installation order matters. + +(0) Install PyTorch and torchaudio +---------------------------------- + +Please refer ``_ to install PyTorch +and torchaudio. (1) Install k2 @@ -54,14 +66,15 @@ to install ``k2``. Please refer to ``_ to install ``lhotse``. -.. HINT:: - Install ``lhotse`` also installs its dependency `torchaudio `_. +.. hint:: -.. CAUTION:: + We strongly recommend you to use:: + + pip install git+https://github.com/lhotse-speech/lhotse + + to install the latest version of lhotse. - If you have installed ``torchaudio``, please consider uninstalling it before - installing ``lhotse``. Otherwise, it may update your already installed PyTorch. (3) Download icefall -------------------- From 2045125fd96a8c0c925f6824d90512e43ac01fb5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 10:43:02 +0800 Subject: [PATCH 03/12] Fix CI. (#280) * Fix CI. 
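* Pin click to 8.0.4: black 21.6b0 imports a private click module
  that click 8.1.0 removed, so an unpinned install fails at startup
  (see psf/black#2964).

A minimal sketch of that failure mode, assuming the import chain
reported in psf/black#2964 (a hypothetical check, not part of this
patch):

    # black 21.6b0 effectively does the following at startup; with
    # click >= 8.1 it raises ImportError, which is why click is
    # pinned to an 8.0.x release here.
    try:
        # click removed the _unicodefun module in 8.1.0
        from click import _unicodefun  # noqa: F401
    except ImportError:
        raise SystemExit("click >= 8.1 is incompatible with black 21.6b0")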
--- .github/workflows/style_check.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 2a743705a..6b3d856df 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -45,7 +45,9 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 + python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 + # See https://github.com/psf/black/issues/2964 + # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 - name: Run flake8 shell: bash From fc40bfea8222400ffdcb437d0d4708053a619cb2 Mon Sep 17 00:00:00 2001 From: "LIyong.Guo" <839019390@qq.com> Date: Thu, 31 Mar 2022 10:43:46 +0800 Subject: [PATCH 04/12] fix typo of torch.eig (#281) Co-authored-by: glynpu --- icefall/diagnostics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 08d1628ec..ce4ac1464 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -140,7 +140,7 @@ def get_diagnostics_for_dim( stats = eigs.abs().sqrt() except: # noqa print("Error getting eigenvalues, trying another method.") - eigs, _ = torch.eigs(stats) + eigs = torch.linalg.eigvals(stats) stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: From 9a11808ed36b57cb17cfd328f1a8537f86f468a5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 16:48:46 +0800 Subject: [PATCH 05/12] Set the seed for dataloader. (#282) Also, suppress torch warnings about division by truncation. --- .../ASR/pruned_transducer_stateless/train.py | 7 ++++++- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 10 ++++++++++ egs/librispeech/ASR/transducer/train.py | 7 ++++++- egs/librispeech/ASR/transducer_lstm/train.py | 7 ++++++- egs/librispeech/ASR/transducer_stateless/conformer.py | 7 +++++-- egs/librispeech/ASR/transducer_stateless/train.py | 7 ++++++- .../asr_datamodule.py | 11 +++++++++++ .../ASR/transducer_stateless_multi_datasets/train.py | 7 ++++++- 8 files changed, 56 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 1f52370fd..17f82e601 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple @@ -496,7 +497,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. 
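    # The catch_warnings block above silences PyTorch's deprecation
    # warning for integer tensor `//` (which suggests using
    # torch.div(..., rounding_mode="floor") instead); the computed
    # value itself is unchanged.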
info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a460c8eb8..8790b21e7 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -23,6 +23,7 @@ from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional +import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, @@ -34,6 +35,7 @@ from lhotse.dataset import ( SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -301,12 +303,20 @@ class LibriSpeechAsrDataModule: logging.info("Loading sampler state dict") train_sampler.load_state_dict(sampler_state_dict) + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + + def worker_init_fn(worker_id: int): + fix_random_seed(seed + worker_id) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index a6ce79520..cbd9259e0 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -393,7 +394,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 9f06ed512..eef4d3430 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -397,7 +398,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index fc838f75b..488c82386 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -109,8 +109,11 @@ class Conformer(Transformer): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - # Caution: We assume the subsampling factor is 4! 
- lengths = ((x_lens - 1) // 2 - 1) // 2 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 2cc6480d5..d6827c17c 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -419,7 +420,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py index 669ad1d1b..2ce8d8752 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py @@ -22,6 +22,7 @@ import logging from pathlib import Path from typing import Optional +import torch from lhotse import CutSet, Fbank, FbankConfig from lhotse.dataset import ( BucketingSampler, @@ -34,6 +35,7 @@ from lhotse.dataset.input_strategies import ( OnTheFlyFeatures, PrecomputedFeatures, ) +from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -253,12 +255,21 @@ class AsrDataModule: ) logging.info("About to create train dataloader") + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + + def worker_init_fn(worker_id: int): + fix_random_seed(seed + worker_id) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py index 105f82417..5572d3f4c 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py @@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging import random +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -466,7 +467,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. 
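    # (A hypothetical alternative to suppressing the warning is the
    # new-style division API, e.g.
    #     torch.div(feature_lens, params.subsampling_factor,
    #               rounding_mode="floor")
    # which gives the same result without the deprecation message;
    # keeping `//` presumably preserves compatibility with PyTorch
    # releases older than 1.8, where rounding_mode does not exist.)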
info["loss"] = loss.detach().cpu().item() From e7493ede9069c725e083235b4bfa50bc81e5cf45 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 31 Mar 2022 20:32:00 +0800 Subject: [PATCH 06/12] Don't use a lambda for dataloader's worker_init_fn. (#284) * Don't use a lambda for dataloader's worker_init_fn. --- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 12 +++++++++--- .../asr_datamodule.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 8790b21e7..8dd1459ca 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -41,6 +41,14 @@ from torch.utils.data import DataLoader from icefall.utils import str2bool +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + class LibriSpeechAsrDataModule: """ DataModule for k2 ASR experiments. @@ -306,9 +314,7 @@ class LibriSpeechAsrDataModule: # 'seed' is derived from the current random state, which will have # previously been set in the main process. seed = torch.randint(0, 100000, ()).item() - - def worker_init_fn(worker_id: int): - fix_random_seed(seed + worker_id) + worker_init_fn = _SeedWorkers(seed) train_dl = DataLoader( train, diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py index 2ce8d8752..c6cf739fb 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py @@ -41,6 +41,14 @@ from torch.utils.data import DataLoader from icefall.utils import str2bool +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + class AsrDataModule: def __init__(self, args: argparse.Namespace): self.args = args @@ -259,9 +267,7 @@ class AsrDataModule: # 'seed' is derived from the current random state, which will have # previously been set in the main process. seed = torch.randint(0, 100000, ()).item() - - def worker_init_fn(worker_id: int): - fix_random_seed(seed + worker_id) + worker_init_fn = _SeedWorkers(seed) train_dl = DataLoader( train, From 0b6a2213c389b2663d1adccb690a3df1f1b1f5a9 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 2 Apr 2022 15:01:45 +0800 Subject: [PATCH 07/12] Modify icefall/__init__.py. (#287) * Modify icefall/__init__.py to import common functions defined in icefall/utils.py. * Modify icefall/__init__.py and .flake8. 
--- .flake8 | 3 ++- icefall/__init__.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 229cf1d6c..dd9239b2d 100644 --- a/.flake8 +++ b/.flake8 @@ -13,4 +13,5 @@ per-file-ignores = exclude = .git, **/data/**, - icefall/shared/make_kn_lm.py + icefall/shared/make_kn_lm.py, + icefall/__init__.py diff --git a/icefall/__init__.py b/icefall/__init__.py index e69de29bb..983539d6f 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -0,0 +1,24 @@ +from .utils import ( + AttributeDict, + MetricsTracker, + add_eos, + add_sos, + concat, + encode_supervisions, + get_alignments, + get_executor, + get_texts, + l1_norm, + l2_norm, + linf_norm, + load_alignments, + make_pad_mask, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + save_alignments, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) From 87cf9231ea73631f1e4453400b3be06d45bcebf5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 3 Apr 2022 13:02:08 +0800 Subject: [PATCH 08/12] Support specifying iteration number of checkpoints for decoding. (#289) --- .../ASR/pruned_transducer_stateless/decode.py | 55 +++++++++++++------ icefall/checkpoint.py | 43 +++++++++++++-- 2 files changed, 76 insertions(+), 22 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 8e924bf96..49b1308b0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -98,27 +98,28 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--avg", type=int, default=15, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. 
- """, + "'--epoch' and '--iter'", ) parser.add_argument( @@ -453,13 +454,19 @@ def main(): ) params.res_dir = params.exp_dir / params.decoding_method - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -485,8 +492,20 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 251456c95..1ef05d964 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx( ) -def find_checkpoints(out_dir: Path) -> List[str]: +def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: """Find all available checkpoints in a directory. The checkpoint filenames have the form: `checkpoint-xxx.pt` where xxx is a numerical value. + Assume you have the following checkpoints in the folder `foo`: + + - checkpoint-1.pt + - checkpoint-20.pt + - checkpoint-300.pt + - checkpoint-4000.pt + + Case 1 (Return all checkpoints):: + + find_checkpoints(out_dir='foo') + + Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e., + checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt) + + find_checkpoints(out_dir='foo', iteration=20) + + Case 3 (Return checkpoints older than checkpoint-20.pt, i.e., + checkpoint-20.pt, checkpoint-1.pt):: + + find_checkpoints(out_dir='foo', iteration=-20) + Args: out_dir: The directory where to search for checkpoints. + iteration: + If it is 0, return all available checkpoints. + If it is positive, return the checkpoints whose iteration number is + greater than or equal to `iteration`. + If it is negative, return the checkpoints whose iteration number is + less than or equal to `-iteration`. Returns: Return a list of checkpoint filenames, sorted in descending order by the numerical value in the filename. """ checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) pattern = re.compile(r"checkpoint-([0-9]+).pt") - idx_checkpoints = [ + iter_checkpoints = [ (int(pattern.search(c).group(1)), c) for c in checkpoints ] + # iter_checkpoints is a list of tuples. 
Each tuple contains + # two elements: (iteration_number, checkpoint-iteration_number.pt) + + iter_checkpoints = sorted( + iter_checkpoints, reverse=True, key=lambda x: x[0] + ) + if iteration >= 0: + ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration] + else: + ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration] - idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0]) - ans = [ic[1] for ic in idx_checkpoints] return ans From cb3ba16f2bf733e63c8f798b27121da93e58738b Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Tue, 5 Apr 2022 10:22:49 +0800 Subject: [PATCH 09/12] Fix aishell prepare.sh when using pre-download data (#291) --- egs/aishell/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 68f5c54d3..26324b0af 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -70,7 +70,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # |-- lexicon.txt # `-- speaker.info - if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then + if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then lhotse download aishell $dl_dir fi From ceeb95bcb8c12e5047be7f12440296bd9532c0e3 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Wed, 6 Apr 2022 11:55:29 +0800 Subject: [PATCH 10/12] update icefall/__init__.py to import more common functions. (#294) --- icefall/__init__.py | 31 +++++++++++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 32 insertions(+) diff --git a/icefall/__init__.py b/icefall/__init__.py index 983539d6f..f466d6a62 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -1,3 +1,34 @@ +from .checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, + remove_checkpoints, + save_checkpoint, + save_checkpoint_with_global_batch_idx, +) + +from .decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) + +from .dist import ( + cleanup_dist, + setup_dist, +) + +from .env import ( + get_env_info, + get_git_branch_name, + get_git_date, + get_git_sha1, +) + from .utils import ( AttributeDict, MetricsTracker, diff --git a/pyproject.toml b/pyproject.toml index 01ff869db..ec5623f90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [tool.isort] profile = "black" +skip = ["icefall/__init__.py"] [tool.black] line-length = 80 From 7c0070e6f6aa5c133805c7c7f2818691cb69d34b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 8 Apr 2022 11:39:54 +0800 Subject: [PATCH 11/12] Display torch version in the training log. (#299) --- icefall/env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/icefall/env.py b/icefall/env.py index 0684c4bf1..c29cbb078 100644 --- a/icefall/env.py +++ b/icefall/env.py @@ -95,6 +95,7 @@ def get_env_info() -> Dict[str, Any]: "k2-git-sha1": k2.version.__git_sha1__, "k2-git-date": k2.version.__git_date__, "lhotse-version": lhotse.__version__, + "torch-version": torch.__version__, "torch-cuda-available": torch.cuda.is_available(), "torch-cuda-version": torch.version.cuda, "python-version": sys.version[:3], From 78b8792d1d3b15008378b0e38d533a77b456bbbd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 8 Apr 2022 13:41:33 +0800 Subject: [PATCH 12/12] Fix potential bugs in PyTorch that exist in label_smoothing. 
(#300) --- .../ASR/conformer_ctc/label_smoothing.py | 99 +------------------ .../ASR/conformer_mmi/label_smoothing.py | 99 +------------------ .../ASR/conformer_ctc/label_smoothing.py | 17 +++- 3 files changed, 17 insertions(+), 198 deletions(-) mode change 100644 => 120000 egs/aishell/ASR/conformer_ctc/label_smoothing.py mode change 100644 => 120000 egs/aishell/ASR/conformer_mmi/label_smoothing.py diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py deleted file mode 100644 index cdc85ce9a..000000000 --- a/egs/aishell/ASR/conformer_ctc/label_smoothing.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - - -class LabelSmoothingLoss(torch.nn.Module): - """ - Implement the LabelSmoothingLoss proposed in the following paper - https://arxiv.org/pdf/1512.00567.pdf - (Rethinking the Inception Architecture for Computer Vision) - - """ - - def __init__( - self, - ignore_index: int = -1, - label_smoothing: float = 0.1, - reduction: str = "sum", - ) -> None: - """ - Args: - ignore_index: - ignored class id - label_smoothing: - smoothing rate (0.0 means the conventional cross entropy loss) - reduction: - It has the same meaning as the reduction in - `torch.nn.CrossEntropyLoss`. It can be one of the following three - values: (1) "none": No reduction will be applied. (2) "mean": the - mean of the output is taken. (3) "sum": the output will be summed. - """ - super().__init__() - assert 0.0 <= label_smoothing < 1.0 - self.ignore_index = ignore_index - self.label_smoothing = label_smoothing - self.reduction = reduction - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.ignore_index of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. 
- """ - assert x.ndim == 3 - assert target.ndim == 2 - assert x.shape[:2] == target.shape - num_classes = x.size(-1) - x = x.reshape(-1, num_classes) - # Now x is of shape (N*T, C) - - # We don't want to change target in-place below, - # so we make a copy of it here - target = target.clone().reshape(-1) - - ignored = target == self.ignore_index - target[ignored] = 0 - - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) - - true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes - ) - # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 - - loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) - if self.reduction == "sum": - return loss.sum() - elif self.reduction == "mean": - return loss.sum() / (~ignored).sum() - else: - return loss.sum(dim=-1) diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_mmi/label_smoothing.py b/egs/aishell/ASR/conformer_mmi/label_smoothing.py deleted file mode 100644 index cdc85ce9a..000000000 --- a/egs/aishell/ASR/conformer_mmi/label_smoothing.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - - -class LabelSmoothingLoss(torch.nn.Module): - """ - Implement the LabelSmoothingLoss proposed in the following paper - https://arxiv.org/pdf/1512.00567.pdf - (Rethinking the Inception Architecture for Computer Vision) - - """ - - def __init__( - self, - ignore_index: int = -1, - label_smoothing: float = 0.1, - reduction: str = "sum", - ) -> None: - """ - Args: - ignore_index: - ignored class id - label_smoothing: - smoothing rate (0.0 means the conventional cross entropy loss) - reduction: - It has the same meaning as the reduction in - `torch.nn.CrossEntropyLoss`. It can be one of the following three - values: (1) "none": No reduction will be applied. (2) "mean": the - mean of the output is taken. (3) "sum": the output will be summed. - """ - super().__init__() - assert 0.0 <= label_smoothing < 1.0 - self.ignore_index = ignore_index - self.label_smoothing = label_smoothing - self.reduction = reduction - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.ignore_index of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. 
- """ - assert x.ndim == 3 - assert target.ndim == 2 - assert x.shape[:2] == target.shape - num_classes = x.size(-1) - x = x.reshape(-1, num_classes) - # Now x is of shape (N*T, C) - - # We don't want to change target in-place below, - # so we make a copy of it here - target = target.clone().reshape(-1) - - ignored = target == self.ignore_index - target[ignored] = 0 - - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) - - true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes - ) - # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 - - loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) - if self.reduction == "sum": - return loss.sum() - elif self.reduction == "mean": - return loss.sum() / (~ignored).sum() - else: - return loss.sum(dim=-1) diff --git a/egs/aishell/ASR/conformer_mmi/label_smoothing.py b/egs/aishell/ASR/conformer_mmi/label_smoothing.py new file mode 120000 index 000000000..08734abd7 --- /dev/null +++ b/egs/aishell/ASR/conformer_mmi/label_smoothing.py @@ -0,0 +1 @@ +../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index cdc85ce9a..1f2f3b137 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -76,7 +76,11 @@ class LabelSmoothingLoss(torch.nn.Module): target = target.clone().reshape(-1) ignored = target == self.ignore_index - target[ignored] = 0 + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) true_dist = torch.nn.functional.one_hot( target, num_classes=num_classes @@ -86,8 +90,17 @@ class LabelSmoothingLoss(torch.nn.Module): true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) + # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) if self.reduction == "sum":