mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
Print environment information.
Print information about k2, lhotse, PyTorch, and icefall.
This commit is contained in:
parent
5072e28afb
commit
8f64fb9921
@ -38,14 +38,16 @@ python conformer_ctc/train.py --bucketing-sampler True \
|
|||||||
--concatenate-cuts False \
|
--concatenate-cuts False \
|
||||||
--max-duration 200 \
|
--max-duration 200 \
|
||||||
--full-libri True \
|
--full-libri True \
|
||||||
--world-size 4
|
--world-size 4 \
|
||||||
|
--lang-dir data/lang_bpe_5000
|
||||||
|
|
||||||
python conformer_ctc/decode.py --lattice-score-scale 0.5 \
|
python conformer_ctc/decode.py --lattice-score-scale 0.5 \
|
||||||
--epoch 34 \
|
--epoch 34 \
|
||||||
--avg 20 \
|
--avg 20 \
|
||||||
--method attention-decoder \
|
--method attention-decoder \
|
||||||
--max-duration 20 \
|
--max-duration 20 \
|
||||||
--num-paths 100
|
--num-paths 100 \
|
||||||
|
--lang-dir data/lang_bpe_5000
|
||||||
```
|
```
|
||||||
|
|
||||||
### LibriSpeech training results (Tdnn-Lstm)
|
### LibriSpeech training results (Tdnn-Lstm)
|
||||||
|
@ -42,6 +42,7 @@ from icefall.decode import (
|
|||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
|
get_env_info,
|
||||||
get_texts,
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
@ -128,6 +129,13 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_5000",
|
||||||
|
help="lang directory",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -135,7 +143,6 @@ def get_params() -> AttributeDict:
|
|||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
{
|
||||||
"exp_dir": Path("conformer_ctc/exp"),
|
"exp_dir": Path("conformer_ctc/exp"),
|
||||||
"lang_dir": Path("data/lang_bpe"),
|
|
||||||
"lm_dir": Path("data/lm"),
|
"lm_dir": Path("data/lm"),
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"nhead": 8,
|
"nhead": 8,
|
||||||
@ -151,6 +158,7 @@ def get_params() -> AttributeDict:
|
|||||||
"min_active_states": 30,
|
"min_active_states": 30,
|
||||||
"max_active_states": 10000,
|
"max_active_states": 10000,
|
||||||
"use_double_scores": True,
|
"use_double_scores": True,
|
||||||
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return params
|
return params
|
||||||
|
@ -34,7 +34,7 @@ from icefall.decode import (
|
|||||||
rescore_with_attention_decoder,
|
rescore_with_attention_decoder,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, get_texts
|
from icefall.utils import AttributeDict, get_env_info, get_texts
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -224,6 +224,7 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
params["env_info"] = get_env_info()
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -43,6 +43,7 @@ from icefall.lexicon import Lexicon
|
|||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
|
get_env_info,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
@ -74,6 +75,13 @@ def get_parser():
|
|||||||
help="Should various information be logged in tensorboard.",
|
help="Should various information be logged in tensorboard.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_5000",
|
||||||
|
help="lang directory",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-epochs",
|
"--num-epochs",
|
||||||
type=int,
|
type=int,
|
||||||
@ -108,9 +116,6 @@ def get_params() -> AttributeDict:
|
|||||||
- exp_dir: It specifies the directory where all training related
|
- exp_dir: It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
|
||||||
- lang_dir: It contains language related input files such as
|
|
||||||
"lexicon.txt"
|
|
||||||
|
|
||||||
- lr: It specifies the initial learning rate
|
- lr: It specifies the initial learning rate
|
||||||
|
|
||||||
- feature_dim: The model input dim. It has to match the one used
|
- feature_dim: The model input dim. It has to match the one used
|
||||||
@ -151,7 +156,6 @@ def get_params() -> AttributeDict:
|
|||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
{
|
{
|
||||||
"exp_dir": Path("conformer_ctc/exp"),
|
"exp_dir": Path("conformer_ctc/exp"),
|
||||||
"lang_dir": Path("data/lang_bpe"),
|
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"weight_decay": 1e-6,
|
"weight_decay": 1e-6,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
@ -160,7 +164,7 @@ def get_params() -> AttributeDict:
|
|||||||
"best_train_epoch": -1,
|
"best_train_epoch": -1,
|
||||||
"best_valid_epoch": -1,
|
"best_valid_epoch": -1,
|
||||||
"batch_idx_train": 0,
|
"batch_idx_train": 0,
|
||||||
"log_interval": 10,
|
"log_interval": 50,
|
||||||
"reset_interval": 200,
|
"reset_interval": 200,
|
||||||
"valid_interval": 3000,
|
"valid_interval": 3000,
|
||||||
"beam_size": 10,
|
"beam_size": 10,
|
||||||
@ -176,6 +180,7 @@ def get_params() -> AttributeDict:
|
|||||||
"use_feat_batchnorm": True,
|
"use_feat_batchnorm": True,
|
||||||
"lr_factor": 5.0,
|
"lr_factor": 5.0,
|
||||||
"warm_step": 80000,
|
"warm_step": 80000,
|
||||||
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,6 +41,8 @@ dl_dir=$PWD/download
|
|||||||
# data/lang_bpe_yyy if the array contains xxx, yyy
|
# data/lang_bpe_yyy if the array contains xxx, yyy
|
||||||
vocab_sizes=(
|
vocab_sizes=(
|
||||||
5000
|
5000
|
||||||
|
2000
|
||||||
|
1000
|
||||||
)
|
)
|
||||||
|
|
||||||
# All files generated by this script are saved in "data".
|
# All files generated by this script are saved in "data".
|
||||||
@ -190,5 +192,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
|||||||
./local/compile_hlg.py --lang-dir $lang_dir
|
./local/compile_hlg.py --lang-dir $lang_dir
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
cd data && ln -sfv lang_bpe_5000 lang_bpe
|
|
||||||
|
@ -39,6 +39,7 @@ from icefall.decode import (
|
|||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
|
get_env_info,
|
||||||
get_texts,
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
@ -103,6 +104,7 @@ def get_params() -> AttributeDict:
|
|||||||
# "method": "nbest",
|
# "method": "nbest",
|
||||||
# num_paths is used when method is "nbest" and "nbest-rescoring"
|
# num_paths is used when method is "nbest" and "nbest-rescoring"
|
||||||
"num_paths": 100,
|
"num_paths": 100,
|
||||||
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return params
|
return params
|
||||||
|
@ -34,7 +34,7 @@ from icefall.decode import (
|
|||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, get_texts
|
from icefall.utils import AttributeDict, get_env_info, get_texts
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -159,6 +159,7 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
params["env_info"] = get_env_info()
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -44,6 +44,7 @@ from icefall.lexicon import Lexicon
|
|||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
encode_supervisions,
|
encode_supervisions,
|
||||||
|
get_env_info,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
@ -168,6 +169,7 @@ def get_params() -> AttributeDict:
|
|||||||
"beam_size": 10,
|
"beam_size": 10,
|
||||||
"reduction": "sum",
|
"reduction": "sum",
|
||||||
"use_double_scores": True,
|
"use_double_scores": True,
|
||||||
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ from icefall.decode import get_lattice, one_best_decoding
|
|||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
|
get_env_info,
|
||||||
get_texts,
|
get_texts,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
@ -256,6 +257,7 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
params["env_info"] = get_env_info()
|
||||||
|
|
||||||
setup_logger(f"{params.exp_dir}/log/log-decode")
|
setup_logger(f"{params.exp_dir}/log/log-decode")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
@ -29,7 +29,7 @@ from model import Tdnn
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.decode import get_lattice, one_best_decoding
|
from icefall.decode import get_lattice, one_best_decoding
|
||||||
from icefall.utils import AttributeDict, get_texts
|
from icefall.utils import AttributeDict, get_env_info, get_texts
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -116,6 +116,7 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
params["env_info"] = get_env_info()
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -24,7 +24,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
|||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
from icefall.utils import AttributeDict, get_env_info, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -483,6 +483,7 @@ def run(rank, world_size, args):
|
|||||||
"""
|
"""
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
params["env_info"] = get_env_info()
|
||||||
|
|
||||||
fix_random_seed(42)
|
fix_random_seed(42)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
|
@ -19,14 +19,17 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Iterable, List, TextIO, Tuple, Union
|
from typing import Any, Dict, Iterable, List, TextIO, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
|
import k2.version
|
||||||
import kaldialign
|
import kaldialign
|
||||||
|
import lhotse
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
@ -132,17 +135,82 @@ def setup_logger(
|
|||||||
logging.getLogger("").addHandler(console)
|
logging.getLogger("").addHandler(console)
|
||||||
|
|
||||||
|
|
||||||
def get_env_info():
|
def get_git_sha1():
|
||||||
"""
|
git_commit = (
|
||||||
TODO:
|
subprocess.run(
|
||||||
"""
|
["git", "rev-parse", "--short", "HEAD"],
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
.stdout.decode()
|
||||||
|
.rstrip("\n")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
dirty_commit = (
|
||||||
|
len(
|
||||||
|
subprocess.run(
|
||||||
|
["git", "diff", "--shortstat"],
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
.stdout.decode()
|
||||||
|
.rstrip("\n")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
> 0
|
||||||
|
)
|
||||||
|
git_commit = (
|
||||||
|
git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
|
||||||
|
)
|
||||||
|
return git_commit
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_date():
|
||||||
|
git_date = (
|
||||||
|
subprocess.run(
|
||||||
|
["git", "log", "-1", "--format=%ad", "--date=local"],
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
.stdout.decode()
|
||||||
|
.rstrip("\n")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
return git_date
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_branch_name():
|
||||||
|
git_date = (
|
||||||
|
subprocess.run(
|
||||||
|
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
.stdout.decode()
|
||||||
|
.rstrip("\n")
|
||||||
|
.strip()
|
||||||
|
)
|
||||||
|
return git_date
|
||||||
|
|
||||||
|
|
||||||
|
def get_env_info() -> Dict[str, Any]:
|
||||||
|
"""Get the environment information."""
|
||||||
return {
|
return {
|
||||||
"k2-git-sha1": None,
|
"k2-version": k2.version.__version__,
|
||||||
"k2-version": None,
|
"k2-build-type": k2.version.__build_type__,
|
||||||
"lhotse-version": None,
|
"k2-with-cuda": k2.with_cuda,
|
||||||
"torch-version": None,
|
"k2-git-sha1": k2.version.__git_sha1__,
|
||||||
"icefall-sha1": None,
|
"k2-git-date": k2.version.__git_date__,
|
||||||
"icefall-version": None,
|
"lhotse-version": lhotse.__version__,
|
||||||
|
"torch-cuda-available": torch.cuda.is_available(),
|
||||||
|
"torch-cuda-version": torch.version.cuda,
|
||||||
|
"python-version": sys.version[:3],
|
||||||
|
"icefall-git-branch": get_git_branch_name(),
|
||||||
|
"icefall-git-sha1": get_git_sha1(),
|
||||||
|
"icefall-git-date": get_git_date(),
|
||||||
|
"icefall-path": str(Path(__file__).resolve().parent.parent),
|
||||||
|
"k2-path": str(Path(k2.__file__).resolve()),
|
||||||
|
"lhotse-path": str(Path(lhotse.__file__).resolve()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,12 @@ import k2
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from icefall.utils import AttributeDict, encode_supervisions, get_texts
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
encode_supervisions,
|
||||||
|
get_env_info,
|
||||||
|
get_texts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -108,3 +113,8 @@ def test_attribute_dict():
|
|||||||
assert s["b"] == 20
|
assert s["b"] == 20
|
||||||
s.c = 100
|
s.c = 100
|
||||||
assert s["c"] == 100
|
assert s["c"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_env_info():
|
||||||
|
s = get_env_info()
|
||||||
|
print(s)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user