mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-11 19:12:30 +00:00
update logging information
This commit is contained in:
parent
93a168ab06
commit
27428187d0
@ -21,12 +21,16 @@ import copy
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Tuple
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from lhotse.features.io import NumpyHdf5Reader
|
||||
from tqdm import tqdm
|
||||
|
||||
from icefall.utils import AttributeDict
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
)
|
||||
|
||||
from train import get_params
|
||||
from graph import ctc_trivial_decoding_graph
|
||||
@ -242,26 +246,33 @@ def get_parser():
|
||||
description="A simple FST decoder for the wake word detection\n"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoding-graph", help="decoding graph", default="himia_ctc_graph.txt"
|
||||
"--post-h5",
|
||||
type=str,
|
||||
help="model output in h5 format",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--score-file",
|
||||
type=str,
|
||||
help="file to save scores of each utterance",
|
||||
)
|
||||
parser.add_argument("--post-h5", help="model output in h5 format")
|
||||
parser.add_argument("--score-file", help="file to save scores of each utterance")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
|
||||
)
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
post_dir = Path(params.post_h5).parent
|
||||
test_set = Path(params.post_h5).stem
|
||||
setup_logger(f"{post_dir}/log/log-decode-{test_set}")
|
||||
|
||||
keys = NumpyHdf5Reader(params.post_h5).hdf.keys()
|
||||
graph = FiniteStateTransducer(ctc_trivial_decoding_graph(params.wakeup_word_tokens))
|
||||
|
||||
logging.info(f"Graph used:\n{graph.to_str()}")
|
||||
logging.info("About to load data to decoder.")
|
||||
|
||||
logging.info(f"About to load {test_set}.")
|
||||
keys = NumpyHdf5Reader(params.post_h5).hdf.keys()
|
||||
with ProcessPoolExecutor() as executor, open(
|
||||
params.score_file, "w", encoding="utf8"
|
||||
) as fout:
|
||||
@ -269,11 +280,13 @@ def main():
|
||||
executor.submit(decode_utt, params, key, params.post_h5, graph)
|
||||
for key in tqdm(keys)
|
||||
]
|
||||
logging.info("Decoding.")
|
||||
logging.info(f"Decoding {test_set}.")
|
||||
for future in tqdm(futures):
|
||||
k, v = future.result()
|
||||
fout.write(str(k) + " " + str(v) + "\n")
|
||||
|
||||
logging.info(f"Finish decoding {test_set}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -140,7 +140,7 @@ def main():
|
||||
out_dir = f"{params.exp_dir}/post/epoch_{params.epoch}-avg_{params.avg}/"
|
||||
Path(out_dir).mkdir(parents=True, exist_ok=True)
|
||||
params.out_dir = out_dir
|
||||
setup_logger(f"{out_dir}/log-decode")
|
||||
setup_logger(f"{out_dir}/log/log-inference")
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
@ -186,6 +186,7 @@ def main():
|
||||
test_dls = [aishell_test_dl, test_dl, cw_test_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dls):
|
||||
logging.info(f"About to inference {test_set}")
|
||||
inference_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
@ -193,6 +194,8 @@ def main():
|
||||
test_set=test_set,
|
||||
)
|
||||
|
||||
logging.info(f"finish inferencing {test_set}")
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
|
@ -64,7 +64,9 @@ class WakeupWordTokenizer(object):
|
||||
self.negative_word_tokens = [1]
|
||||
self.negative_number_tokens = 1
|
||||
|
||||
def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
def texts_to_token_ids(
|
||||
self, texts: List[str]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
"""Convert a list of texts to a list of k2.Fsa based texts.
|
||||
|
||||
Args:
|
||||
|
@ -564,7 +564,7 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||
fix_random_seed(params.seed + epoch)
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
||||
|
@ -25,16 +25,29 @@ import numpy as np
|
||||
from pathlib import Path
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
|
||||
from icefall.utils import setup_logger
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--positive-score-file", required=True, help="score file of positive data"
|
||||
"--positive-score-file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="score file of positive data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--negative-score-file", required=True, help="score file of negative data"
|
||||
"--negative-score-file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="score file of negative data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--legend",
|
||||
type=str,
|
||||
required=True,
|
||||
help="legend of ROC curve picture.",
|
||||
)
|
||||
parser.add_argument("--legend", required=True, help="utt2dur file of negative data")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -88,10 +101,13 @@ def get_roc_and_auc(
|
||||
|
||||
|
||||
def main():
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
args = get_args()
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
score_dir = Path(args.positive_score_file).parent
|
||||
setup_logger(f"{score_dir}/log/log-auc-{args.legend}")
|
||||
logging.info(f"About to compute AUC of {args.legend}")
|
||||
|
||||
pos_dict = load_score(args.positive_score_file)
|
||||
neg_dict = load_score(args.negative_score_file)
|
||||
fpr, tpr, roc_auc = get_roc_and_auc(pos_dict, neg_dict)
|
||||
|
@ -39,7 +39,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Decode and compute area under curve(AUC)"
|
||||
for test_set in test aishell_test cw_test; do
|
||||
python ctc_tdnn/decode.py \
|
||||
--decoding-graph ./data/LG.int \
|
||||
--post-h5 ${post_dir}/${test_set}.h5 \
|
||||
--score-file ${post_dir}/fst_${test_set}_pos_h5.txt
|
||||
done
|
||||
|
Loading…
x
Reference in New Issue
Block a user