k2-fsa/icefall (https://github.com/k2-fsa/icefall.git), commit 27428187d0 (parent 93a168ab06): update logging information
@@ -21,12 +21,16 @@ import copy
 import logging
 from concurrent.futures import ProcessPoolExecutor
 from typing import Tuple
+from pathlib import Path
 
 import numpy as np
 from lhotse.features.io import NumpyHdf5Reader
 from tqdm import tqdm
 
-from icefall.utils import AttributeDict
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+)
 
 from train import get_params
 from graph import ctc_trivial_decoding_graph
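The parenthesized import brings in setup_logger alongside AttributeDict; the hunks below use it to route logs into per-run files instead of configuring logging.basicConfig inline. As a rough idea of what such a helper has to do (an illustrative sketch under my own naming, not icefall's actual implementation), it points the root logger at a log file and mirrors records to the console:

    import logging
    from pathlib import Path

    def setup_logger_sketch(log_filename: str, log_level: int = logging.INFO) -> None:
        # Illustrative only: write records to a file and echo them to the console.
        Path(log_filename).parent.mkdir(parents=True, exist_ok=True)
        fmt = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        logging.basicConfig(filename=log_filename, format=fmt, level=log_level)
        console = logging.StreamHandler()
        console.setLevel(log_level)
        console.setFormatter(logging.Formatter(fmt))
        logging.getLogger().addHandler(console)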
@@ -242,26 +246,33 @@ def get_parser():
         description="A simple FST decoder for the wake word detection\n"
     )
     parser.add_argument(
-        "--decoding-graph", help="decoding graph", default="himia_ctc_graph.txt"
+        "--post-h5",
+        type=str,
+        help="model output in h5 format",
+    )
+    parser.add_argument(
+        "--score-file",
+        type=str,
+        help="file to save scores of each utterance",
     )
-    parser.add_argument("--post-h5", help="model output in h5 format")
-    parser.add_argument("--score-file", help="file to save scores of each utterance")
     return parser
 
 
 def main():
-    logging.basicConfig(
-        level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
-    )
     parser = get_parser()
     args = parser.parse_args()
     params = get_params()
     params.update(vars(args))
-    keys = NumpyHdf5Reader(params.post_h5).hdf.keys()
+    post_dir = Path(params.post_h5).parent
+    test_set = Path(params.post_h5).stem
+    setup_logger(f"{post_dir}/log/log-decode-{test_set}")
+
     graph = FiniteStateTransducer(ctc_trivial_decoding_graph(params.wakeup_word_tokens))
 
     logging.info(f"Graph used:\n{graph.to_str()}")
-    logging.info("About to load data to decoder.")
+    logging.info(f"About to load {test_set}.")
+    keys = NumpyHdf5Reader(params.post_h5).hdf.keys()
+
     with ProcessPoolExecutor() as executor, open(
         params.score_file, "w", encoding="utf8"
     ) as fout:
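Note how the log path is derived from --post-h5 itself: Path(...).parent gives the directory holding the posteriors and Path(...).stem gives the test-set name, so every test set logs to its own file. This is standard pathlib behavior, shown here with a hypothetical path:

    from pathlib import Path

    p = Path("exp/post/epoch_10-avg_2/aishell_test.h5")
    print(p.parent)  # exp/post/epoch_10-avg_2
    print(p.stem)    # aishell_test  (used as the test_set name above)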
@@ -269,11 +280,13 @@ def main():
             executor.submit(decode_utt, params, key, params.post_h5, graph)
             for key in tqdm(keys)
         ]
-        logging.info("Decoding.")
+        logging.info(f"Decoding {test_set}.")
         for future in tqdm(futures):
             k, v = future.result()
             fout.write(str(k) + " " + str(v) + "\n")
 
+    logging.info(f"Finish decoding {test_set}.")
+
 
 if __name__ == "__main__":
     main()
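The surrounding with-block fans one decode_utt job per utterance out to a process pool and writes "key score" lines as the results come back. A self-contained sketch of the same pattern, with a stand-in scorer in place of decode_utt:

    from concurrent.futures import ProcessPoolExecutor

    def score_utt(key: str):
        # Stand-in for decode_utt: return the utterance key and a dummy score.
        return key, len(key) * 0.1

    def run(keys, score_file: str) -> None:
        with ProcessPoolExecutor() as executor, open(
            score_file, "w", encoding="utf8"
        ) as fout:
            futures = [executor.submit(score_utt, key) for key in keys]
            for future in futures:
                k, v = future.result()  # blocks until this job finishes
                fout.write(f"{k} {v}\n")

    if __name__ == "__main__":
        run(["utt1", "utt2", "utt3"], "scores.txt")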
@@ -140,7 +140,7 @@ def main():
     out_dir = f"{params.exp_dir}/post/epoch_{params.epoch}-avg_{params.avg}/"
     Path(out_dir).mkdir(parents=True, exist_ok=True)
     params.out_dir = out_dir
-    setup_logger(f"{out_dir}/log-decode")
+    setup_logger(f"{out_dir}/log/log-inference")
     logging.info("Decoding started")
     logging.info(params)
 
@@ -186,6 +186,7 @@ def main():
     test_dls = [aishell_test_dl, test_dl, cw_test_dl]
 
     for test_set, test_dl in zip(test_sets, test_dls):
+        logging.info(f"About to inference {test_set}")
         inference_dataset(
             dl=test_dl,
             params=params,
@@ -193,6 +194,8 @@ def main():
             test_set=test_set,
         )
 
+        logging.info(f"finish inferencing {test_set}")
+
     logging.info("Done!")
 
 
@@ -64,7 +64,9 @@ class WakeupWordTokenizer(object):
         self.negative_word_tokens = [1]
         self.negative_number_tokens = 1
 
-    def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, int]:
+    def texts_to_token_ids(
+        self, texts: List[str]
+    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
         """Convert a list of texts to a list of k2.Fsa based texts.
 
         Args:
@@ -564,7 +564,7 @@ def run(rank, world_size, args):
         params=params,
     )
 
-    for epoch in range(params.start_epoch, params.num_epochs):
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
 
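Unlike the rest of the commit, this hunk changes behavior: range() excludes its upper bound, so the final configured epoch was previously skipped. With start_epoch=1 and num_epochs=10:

    start_epoch, num_epochs = 1, 10
    list(range(start_epoch, num_epochs))      # [1, 2, ..., 9]   last epoch skipped
    list(range(start_epoch, num_epochs + 1))  # [1, 2, ..., 10]  all epochs trained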
@@ -25,16 +25,29 @@ import numpy as np
 from pathlib import Path
 from sklearn.metrics import roc_curve, auc
 
+from icefall.utils import setup_logger
+
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--positive-score-file", required=True, help="score file of positive data"
+        "--positive-score-file",
+        type=str,
+        required=True,
+        help="score file of positive data",
     )
     parser.add_argument(
-        "--negative-score-file", required=True, help="score file of negative data"
+        "--negative-score-file",
+        type=str,
+        required=True,
+        help="score file of negative data",
+    )
+    parser.add_argument(
+        "--legend",
+        type=str,
+        required=True,
+        help="legend of ROC curve picture.",
     )
-    parser.add_argument("--legend", required=True, help="utt2dur file of negative data")
     return parser.parse_args()
 
 
@@ -88,10 +101,13 @@ def get_roc_and_auc(
 
 
 def main():
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     args = get_args()
-    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    score_dir = Path(args.positive_score_file).parent
+    setup_logger(f"{score_dir}/log/log-auc-{args.legend}")
+
+    logging.info(f"About to compute AUC of {args.legend}")
 
     pos_dict = load_score(args.positive_score_file)
     neg_dict = load_score(args.negative_score_file)
     fpr, tpr, roc_auc = get_roc_and_auc(pos_dict, neg_dict)
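load_score and get_roc_and_auc are defined elsewhere in this script; the computation ultimately rests on sklearn's roc_curve and auc. A minimal sketch under the assumption that each score file maps utterance ids to detection scores (the function name and data below are made up for illustration):

    import numpy as np
    from sklearn.metrics import roc_curve, auc

    def compute_auc(pos_scores: dict, neg_scores: dict) -> float:
        # Positives get label 1, negatives label 0; AUC integrates the ROC curve.
        y_score = np.array(list(pos_scores.values()) + list(neg_scores.values()))
        y_true = np.array([1] * len(pos_scores) + [0] * len(neg_scores))
        fpr, tpr, _ = roc_curve(y_true, y_score)
        return auc(fpr, tpr)

    print(compute_auc({"a": 0.9, "b": 0.8}, {"c": 0.2, "d": 0.4}))  # -> 1.0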
@@ -39,7 +39,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   log "Stage 2: Decode and compute area under curve(AUC)"
   for test_set in test aishell_test cw_test; do
     python ctc_tdnn/decode.py \
-      --decoding-graph ./data/LG.int \
      --post-h5 ${post_dir}/${test_set}.h5 \
      --score-file ${post_dir}/fst_${test_set}_pos_h5.txt
   done