diff --git a/egs/himia/wuw/ctc_tdnn/decode.py b/egs/himia/wuw/ctc_tdnn/decode.py index 7acc7d595..6715c8b9c 100755 --- a/egs/himia/wuw/ctc_tdnn/decode.py +++ b/egs/himia/wuw/ctc_tdnn/decode.py @@ -21,12 +21,16 @@ import copy import logging from concurrent.futures import ProcessPoolExecutor from typing import Tuple +from pathlib import Path import numpy as np from lhotse.features.io import NumpyHdf5Reader from tqdm import tqdm -from icefall.utils import AttributeDict +from icefall.utils import ( + AttributeDict, + setup_logger, +) from train import get_params from graph import ctc_trivial_decoding_graph @@ -242,26 +246,33 @@ def get_parser(): description="A simple FST decoder for the wake word detection\n" ) parser.add_argument( - "--decoding-graph", help="decoding graph", default="himia_ctc_graph.txt" + "--post-h5", + type=str, + help="model output in h5 format", + ) + parser.add_argument( + "--score-file", + type=str, + help="file to save scores of each utterance", ) - parser.add_argument("--post-h5", help="model output in h5 format") - parser.add_argument("--score-file", help="file to save scores of each utterance") return parser def main(): - logging.basicConfig( - level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s" - ) parser = get_parser() args = parser.parse_args() params = get_params() params.update(vars(args)) + post_dir = Path(params.post_h5).parent + test_set = Path(params.post_h5).stem + setup_logger(f"{post_dir}/log/log-decode-{test_set}") - keys = NumpyHdf5Reader(params.post_h5).hdf.keys() graph = FiniteStateTransducer(ctc_trivial_decoding_graph(params.wakeup_word_tokens)) + logging.info(f"Graph used:\n{graph.to_str()}") - logging.info("About to load data to decoder.") + + logging.info(f"About to load {test_set}.") + keys = NumpyHdf5Reader(params.post_h5).hdf.keys() with ProcessPoolExecutor() as executor, open( params.score_file, "w", encoding="utf8" ) as fout: @@ -269,11 +280,13 @@ def main(): executor.submit(decode_utt, params, key, params.post_h5, graph) for key in tqdm(keys) ] - logging.info("Decoding.") + logging.info(f"Decoding {test_set}.") for future in tqdm(futures): k, v = future.result() fout.write(str(k) + " " + str(v) + "\n") + logging.info(f"Finish decoding {test_set}.") + if __name__ == "__main__": main() diff --git a/egs/himia/wuw/ctc_tdnn/inference.py b/egs/himia/wuw/ctc_tdnn/inference.py index eae9c5333..10950cec9 100755 --- a/egs/himia/wuw/ctc_tdnn/inference.py +++ b/egs/himia/wuw/ctc_tdnn/inference.py @@ -140,7 +140,7 @@ def main(): out_dir = f"{params.exp_dir}/post/epoch_{params.epoch}-avg_{params.avg}/" Path(out_dir).mkdir(parents=True, exist_ok=True) params.out_dir = out_dir - setup_logger(f"{out_dir}/log-decode") + setup_logger(f"{out_dir}/log/log-inference") logging.info("Decoding started") logging.info(params) @@ -186,6 +186,7 @@ def main(): test_dls = [aishell_test_dl, test_dl, cw_test_dl] for test_set, test_dl in zip(test_sets, test_dls): + logging.info(f"About to inference {test_set}") inference_dataset( dl=test_dl, params=params, @@ -193,6 +194,8 @@ def main(): test_set=test_set, ) + logging.info(f"finish inferencing {test_set}") + logging.info("Done!") diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py index bc207ec04..e019ebb86 100644 --- a/egs/himia/wuw/ctc_tdnn/tokenizer.py +++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py @@ -64,7 +64,9 @@ class WakeupWordTokenizer(object): self.negative_word_tokens = [1] self.negative_number_tokens = 1 - def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, int]: + def texts_to_token_ids( + self, texts: List[str] + ) -> Tuple[torch.Tensor, torch.Tensor, int]: """Convert a list of texts to a list of k2.Fsa based texts. Args: diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py index 249821c29..fd9d42cad 100755 --- a/egs/himia/wuw/ctc_tdnn/train.py +++ b/egs/himia/wuw/ctc_tdnn/train.py @@ -564,7 +564,7 @@ def run(rank, world_size, args): params=params, ) - for epoch in range(params.start_epoch, params.num_epochs): + for epoch in range(params.start_epoch, params.num_epochs + 1): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) diff --git a/egs/himia/wuw/local/auc.py b/egs/himia/wuw/local/auc.py index d7357a0f1..7b35ef06b 100755 --- a/egs/himia/wuw/local/auc.py +++ b/egs/himia/wuw/local/auc.py @@ -25,16 +25,29 @@ import numpy as np from pathlib import Path from sklearn.metrics import roc_curve, auc +from icefall.utils import setup_logger + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--positive-score-file", required=True, help="score file of positive data" + "--positive-score-file", + type=str, + required=True, + help="score file of positive data", ) parser.add_argument( - "--negative-score-file", required=True, help="score file of negative data" + "--negative-score-file", + type=str, + required=True, + help="score file of negative data", + ) + parser.add_argument( + "--legend", + type=str, + required=True, + help="legend of ROC curve picture.", ) - parser.add_argument("--legend", required=True, help="utt2dur file of negative data") return parser.parse_args() @@ -88,10 +101,13 @@ def get_roc_and_auc( def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" args = get_args() - logging.basicConfig(format=formatter, level=logging.INFO) + + score_dir = Path(args.positive_score_file).parent + setup_logger(f"{score_dir}/log/log-auc-{args.legend}") + logging.info(f"About to compute AUC of {args.legend}") + pos_dict = load_score(args.positive_score_file) neg_dict = load_score(args.negative_score_file) fpr, tpr, roc_auc = get_roc_and_auc(pos_dict, neg_dict) diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh index 16ecacf6a..7dfaeefea 100644 --- a/egs/himia/wuw/run_ctc_tdnn.sh +++ b/egs/himia/wuw/run_ctc_tdnn.sh @@ -39,7 +39,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Decode and compute area under curve(AUC)" for test_set in test aishell_test cw_test; do python ctc_tdnn/decode.py \ - --decoding-graph ./data/LG.int \ --post-h5 ${post_dir}/${test_set}.h5 \ --score-file ${post_dir}/fst_${test_set}_pos_h5.txt done