update logging information

This commit is contained in:
glynpu 2023-03-16 14:50:24 +08:00
parent 93a168ab06
commit 27428187d0
6 changed files with 52 additions and 19 deletions

View File

@ -21,12 +21,16 @@ import copy
import logging
from concurrent.futures import ProcessPoolExecutor
from typing import Tuple
from pathlib import Path
import numpy as np
from lhotse.features.io import NumpyHdf5Reader
from tqdm import tqdm
from icefall.utils import AttributeDict
from icefall.utils import (
AttributeDict,
setup_logger,
)
from train import get_params
from graph import ctc_trivial_decoding_graph
@ -242,26 +246,33 @@ def get_parser():
description="A simple FST decoder for the wake word detection\n"
)
parser.add_argument(
"--decoding-graph", help="decoding graph", default="himia_ctc_graph.txt"
"--post-h5",
type=str,
help="model output in h5 format",
)
parser.add_argument(
"--score-file",
type=str,
help="file to save scores of each utterance",
)
parser.add_argument("--post-h5", help="model output in h5 format")
parser.add_argument("--score-file", help="file to save scores of each utterance")
return parser
def main():
logging.basicConfig(
level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
)
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
post_dir = Path(params.post_h5).parent
test_set = Path(params.post_h5).stem
setup_logger(f"{post_dir}/log/log-decode-{test_set}")
keys = NumpyHdf5Reader(params.post_h5).hdf.keys()
graph = FiniteStateTransducer(ctc_trivial_decoding_graph(params.wakeup_word_tokens))
logging.info(f"Graph used:\n{graph.to_str()}")
logging.info("About to load data to decoder.")
logging.info(f"About to load {test_set}.")
keys = NumpyHdf5Reader(params.post_h5).hdf.keys()
with ProcessPoolExecutor() as executor, open(
params.score_file, "w", encoding="utf8"
) as fout:
@ -269,11 +280,13 @@ def main():
executor.submit(decode_utt, params, key, params.post_h5, graph)
for key in tqdm(keys)
]
logging.info("Decoding.")
logging.info(f"Decoding {test_set}.")
for future in tqdm(futures):
k, v = future.result()
fout.write(str(k) + " " + str(v) + "\n")
logging.info(f"Finish decoding {test_set}.")
if __name__ == "__main__":
main()

View File

@ -140,7 +140,7 @@ def main():
out_dir = f"{params.exp_dir}/post/epoch_{params.epoch}-avg_{params.avg}/"
Path(out_dir).mkdir(parents=True, exist_ok=True)
params.out_dir = out_dir
setup_logger(f"{out_dir}/log-decode")
setup_logger(f"{out_dir}/log/log-inference")
logging.info("Decoding started")
logging.info(params)
@ -186,6 +186,7 @@ def main():
test_dls = [aishell_test_dl, test_dl, cw_test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
logging.info(f"About to inference {test_set}")
inference_dataset(
dl=test_dl,
params=params,
@ -193,6 +194,8 @@ def main():
test_set=test_set,
)
logging.info(f"finish inferencing {test_set}")
logging.info("Done!")

View File

@ -64,7 +64,9 @@ class WakeupWordTokenizer(object):
self.negative_word_tokens = [1]
self.negative_number_tokens = 1
def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, int]:
def texts_to_token_ids(
self, texts: List[str]
) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""Convert a list of texts to a list of k2.Fsa based texts.
Args:

View File

@ -564,7 +564,7 @@ def run(rank, world_size, args):
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
for epoch in range(params.start_epoch, params.num_epochs + 1):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)

View File

@ -25,16 +25,29 @@ import numpy as np
from pathlib import Path
from sklearn.metrics import roc_curve, auc
from icefall.utils import setup_logger
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--positive-score-file", required=True, help="score file of positive data"
"--positive-score-file",
type=str,
required=True,
help="score file of positive data",
)
parser.add_argument(
"--negative-score-file", required=True, help="score file of negative data"
"--negative-score-file",
type=str,
required=True,
help="score file of negative data",
)
parser.add_argument(
"--legend",
type=str,
required=True,
help="legend of ROC curve picture.",
)
parser.add_argument("--legend", required=True, help="utt2dur file of negative data")
return parser.parse_args()
@ -88,10 +101,13 @@ def get_roc_and_auc(
def main():
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
args = get_args()
logging.basicConfig(format=formatter, level=logging.INFO)
score_dir = Path(args.positive_score_file).parent
setup_logger(f"{score_dir}/log/log-auc-{args.legend}")
logging.info(f"About to compute AUC of {args.legend}")
pos_dict = load_score(args.positive_score_file)
neg_dict = load_score(args.negative_score_file)
fpr, tpr, roc_auc = get_roc_and_auc(pos_dict, neg_dict)

View File

@ -39,7 +39,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Decode and compute area under curve(AUC)"
for test_set in test aishell_test cw_test; do
python ctc_tdnn/decode.py \
--decoding-graph ./data/LG.int \
--post-h5 ${post_dir}/${test_set}.h5 \
--score-file ${post_dir}/fst_${test_set}_pos_h5.txt
done