diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index d04e912bf..f58ba6451 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -38,14 +38,16 @@ python conformer_ctc/train.py --bucketing-sampler True \ --concatenate-cuts False \ --max-duration 200 \ --full-libri True \ - --world-size 4 + --world-size 4 \ + --lang-dir data/lang_bpe_5000 python conformer_ctc/decode.py --lattice-score-scale 0.5 \ --epoch 34 \ --avg 20 \ --method attention-decoder \ --max-duration 20 \ - --num-paths 100 + --num-paths 100 \ + --lang-dir data/lang_bpe_5000 ``` ### LibriSpeech training results (Tdnn-Lstm) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 85161f737..c6a6dd85d 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -42,6 +42,7 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -128,6 +129,13 @@ def get_parser(): """, ) + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_5000", + help="lang directory", + ) + return parser @@ -135,7 +143,6 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), "feature_dim": 80, "nhead": 8, @@ -151,6 +158,7 @@ def get_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "env_info": get_env_info(), } ) return params diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 95029fadb..574fafcfe 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -34,7 +34,7 @@ from icefall.decode import ( rescore_with_attention_decoder, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -224,6 +224,7 @@ def main(): params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index b0dbe72ad..e3242c943 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -43,6 +43,7 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, encode_supervisions, + get_env_info, setup_logger, str2bool, ) @@ -74,6 +75,13 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_5000", + help="lang directory", + ) + parser.add_argument( "--num-epochs", type=int, @@ -108,9 +116,6 @@ def get_params() -> AttributeDict: - exp_dir: It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved - - lang_dir: It contains language related input files such as - "lexicon.txt" - - lr: It specifies the initial learning rate - feature_dim: The model input dim. It has to match the one used @@ -151,7 +156,6 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), "feature_dim": 80, "weight_decay": 1e-6, "subsampling_factor": 4, @@ -160,7 +164,7 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 10, + "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, "beam_size": 10, @@ -176,6 +180,7 @@ def get_params() -> AttributeDict: "use_feat_batchnorm": True, "lr_factor": 5.0, "warm_step": 80000, + "env_info": get_env_info(), } ) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index f06e013f6..3a68e0f23 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -41,6 +41,8 @@ dl_dir=$PWD/download # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( 5000 + 2000 + 1000 ) # All files generated by this script are saved in "data". @@ -190,5 +192,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then ./local/compile_hlg.py --lang-dir $lang_dir done fi - -cd data && ln -sfv lang_bpe_5000 lang_bpe diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 23b2e794c..4dda7818d 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -39,6 +39,7 @@ from icefall.decode import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -103,6 +104,7 @@ def get_params() -> AttributeDict: # "method": "nbest", # num_paths is used when method is "nbest" and "nbest-rescoring" "num_paths": 100, + "env_info": get_env_info(), } ) return params diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index 4f82a989c..523f36e3e 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -34,7 +34,7 @@ from icefall.decode import ( one_best_decoding, rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -159,6 +159,7 @@ def main(): params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 4d45d197b..6144f4a54 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -44,6 +44,7 @@ from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, encode_supervisions, + get_env_info, setup_logger, str2bool, ) @@ -168,6 +169,7 @@ def get_params() -> AttributeDict: "beam_size": 10, "reduction": "sum", "use_double_scores": True, + "env_info": get_env_info(), } ) diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 54fdbb3cc..62d8bb9d7 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -17,6 +17,7 @@ from icefall.decode import get_lattice, one_best_decoding from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -256,6 +257,7 @@ def main(): params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() setup_logger(f"{params.exp_dir}/log/log-decode") logging.info("Decoding started") diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index fb92110e3..5b85008a6 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -29,7 +29,7 @@ from model import Tdnn from torch.nn.utils.rnn import pad_sequence from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -116,6 +116,7 @@ def main(): params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 39c5ef3ef..f2e986688 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, setup_logger, str2bool +from icefall.utils import AttributeDict, get_env_info, setup_logger, str2bool def get_parser(): @@ -483,6 +483,7 @@ def run(rank, world_size, args): """ params = get_params() params.update(vars(args)) + params["env_info"] = get_env_info() fix_random_seed(42) if world_size > 1: diff --git a/icefall/utils.py b/icefall/utils.py index 1016bcd35..ad08e4d8f 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -19,14 +19,17 @@ import argparse import logging import os import subprocess +import sys from collections import defaultdict from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, TextIO, Tuple, Union +from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union import k2 +import k2.version import kaldialign +import lhotse import torch import torch.distributed as dist @@ -132,17 +135,82 @@ def setup_logger( logging.getLogger("").addHandler(console) -def get_env_info(): - """ - TODO: - """ +def get_git_sha1(): + git_commit = ( + subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + dirty_commit = ( + len( + subprocess.run( + ["git", "diff", "--shortstat"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + > 0 + ) + git_commit = ( + git_commit + "-dirty" if dirty_commit else git_commit + "-clean" + ) + return git_commit + + +def get_git_date(): + git_date = ( + subprocess.run( + ["git", "log", "-1", "--format=%ad", "--date=local"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + return git_date + + +def get_git_branch_name(): + git_date = ( + subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + check=True, + stdout=subprocess.PIPE, + ) + .stdout.decode() + .rstrip("\n") + .strip() + ) + return git_date + + +def get_env_info() -> Dict[str, Any]: + """Get the environment information.""" return { - "k2-git-sha1": None, - "k2-version": None, - "lhotse-version": None, - "torch-version": None, - "icefall-sha1": None, - "icefall-version": None, + "k2-version": k2.version.__version__, + "k2-build-type": k2.version.__build_type__, + "k2-with-cuda": k2.with_cuda, + "k2-git-sha1": k2.version.__git_sha1__, + "k2-git-date": k2.version.__git_date__, + "lhotse-version": lhotse.__version__, + "torch-cuda-available": torch.cuda.is_available(), + "torch-cuda-version": torch.version.cuda, + "python-version": sys.version[:3], + "icefall-git-branch": get_git_branch_name(), + "icefall-git-sha1": get_git_sha1(), + "icefall-git-date": get_git_date(), + "icefall-path": str(Path(__file__).resolve().parent.parent), + "k2-path": str(Path(k2.__file__).resolve()), + "lhotse-path": str(Path(lhotse.__file__).resolve()), } diff --git a/test/test_utils.py b/test/test_utils.py index b4c9358fd..8b0c03e95 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,7 +20,12 @@ import k2 import pytest import torch -from icefall.utils import AttributeDict, encode_supervisions, get_texts +from icefall.utils import ( + AttributeDict, + encode_supervisions, + get_env_info, + get_texts, +) @pytest.fixture @@ -108,3 +113,8 @@ def test_attribute_dict(): assert s["b"] == 20 s.c = 100 assert s["c"] == 100 + + +def test_get_env_info(): + s = get_env_info() + print(s)