From e64a6e7becfeeccd1e7df2abeaf7c49c7a74dd8b Mon Sep 17 00:00:00 2001
From: glynpu
Date: Thu, 16 Mar 2023 20:03:57 +0800
Subject: [PATCH] update comments

---
 egs/himia/wuw/README.md                    | 10 +++++++
 egs/himia/wuw/ctc_tdnn/README.md           | 10 -------
 egs/himia/wuw/ctc_tdnn/decode.py           | 33 ++++++++++++++++++----
 egs/himia/wuw/ctc_tdnn/graph.py            |  2 +-
 egs/himia/wuw/ctc_tdnn/inference.py        |  4 +--
 egs/himia/wuw/ctc_tdnn/tdnn.py             |  2 +-
 egs/himia/wuw/ctc_tdnn/tokenizer.py        |  2 +-
 egs/himia/wuw/ctc_tdnn/train.py            |  8 +++---
 egs/himia/wuw/local/auc.py                 |  6 ++--
 egs/himia/wuw/local/compute_fbank_himia.py |  4 +--
 egs/himia/wuw/prepare.sh                   | 12 +++++---
 egs/himia/wuw/run_ctc_tdnn.sh              | 15 +++++-----
 12 files changed, 67 insertions(+), 41 deletions(-)
 create mode 100644 egs/himia/wuw/README.md
 delete mode 100644 egs/himia/wuw/ctc_tdnn/README.md

diff --git a/egs/himia/wuw/README.md b/egs/himia/wuw/README.md
new file mode 100644
index 000000000..59dba046e
--- /dev/null
+++ b/egs/himia/wuw/README.md
@@ -0,0 +1,10 @@
+# Pretrained models and related logs/results.
+
+## ctc tdnn baseline
+
+AUC results for different epochs can be found at
+
+E.g., for epoch 15 and avg 1, the result log file is:
+
+The corresponding ROC curve is:
+
diff --git a/egs/himia/wuw/ctc_tdnn/README.md b/egs/himia/wuw/ctc_tdnn/README.md
deleted file mode 100644
index 4bf30774e..000000000
--- a/egs/himia/wuw/ctc_tdnn/README.md
+++ /dev/null
@@ -1,10 +0,0 @@
-# Pretrained models and releated logs/results.
-
-## ctc tdnn baseline
-
-Auc results for different epochs could be found at
-
-E.g. for epoch 2 and avg 1, auc log file is:
-
-Corresponding ROC curve is:
-
diff --git a/egs/himia/wuw/ctc_tdnn/decode.py b/egs/himia/wuw/ctc_tdnn/decode.py
index 6715c8b9c..9d05a3310 100755
--- a/egs/himia/wuw/ctc_tdnn/decode.py
+++ b/egs/himia/wuw/ctc_tdnn/decode.py
@@ -61,28 +61,49 @@ class FiniteStateTransducer:
     """Represents a decoding graph for wake word detection."""
 
     def __init__(self, graph: str) -> None:
+        """
+        Construct a decoding graph (FST) from its string representation.
+
+        Args:
+          graph: An FST in string format. Arcs are separated by "\n".
+        """
         self.state_list = list()
         for arc_str in graph.split("\n"):
             arc = arc_str.strip().split()
             if len(arc) == 0:
                 continue
+            # A line may contain 1, 2 or 4 fields, with format:
+            #   src_state [dst_state] [ilabel] [olabel]
             # 1 and 2 for final state
             # 4 for non-final state
             assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}"
+            arc = [int(element) for element in arc]
+            src_state_id = arc[0]
+            max_state_id = len(self.state_list) - 1
             if len(arc) == 4:
                 # Non-final state
-                # FST must be sorted
-                if len(self.state_list) <= int(arc[0]):
+                assert max_state_id <= src_state_id, (
+                    f"FST must be sorted by src_state, but src_state "
+                    f"{src_state_id} < max state id {max_state_id}. Check your graph."
+                )
+                if max_state_id < src_state_id:
                     new_state = State()
                     self.state_list.append(new_state)
-                self.state_list[int(arc[0])].add_arc(
-                    Arc(arc[0], arc[1], arc[2], arc[3])
+
+                self.state_list[src_state_id].add_arc(
+                    Arc(src_state_id, arc[1], arc[2], arc[3])
                 )
             else:
-                self.final_state_id = int(arc[0])
+                assert (
+                    max_state_id == src_state_id
+                ), "Final state seems unreachable. Check your graph."
+                self.final_state_id = src_state_id
 
     def to_str(self) -> None:
         fst_str = ""
-        for state_idx in range(len(self.state_list)):
+        number_states = len(self.state_list)
+        if number_states == 0:
+            return fst_str
+        for state_idx in range(number_states):
             cur_state = self.state_list[state_idx]
             for arc_idx in range(len(cur_state.arc_list)):
                 cur_arc = cur_state.arc_list[arc_idx]
diff --git a/egs/himia/wuw/ctc_tdnn/graph.py b/egs/himia/wuw/ctc_tdnn/graph.py
index 184e01ed1..60e8afe2e 100644
--- a/egs/himia/wuw/ctc_tdnn/graph.py
+++ b/egs/himia/wuw/ctc_tdnn/graph.py
@@ -21,7 +21,7 @@ from typing import List
 
 def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
     """
-    A graph starts with blank/unknown and follwoing by wakeup word.
+    A graph that starts with blank/unknown, followed by the wakeup word.
 
     Args:
       wakeup_word_tokens: A sequence of token ids corresponding wakeup_word.
diff --git a/egs/himia/wuw/ctc_tdnn/inference.py b/egs/himia/wuw/ctc_tdnn/inference.py
index 10950cec9..b530eda62 100755
--- a/egs/himia/wuw/ctc_tdnn/inference.py
+++ b/egs/himia/wuw/ctc_tdnn/inference.py
@@ -69,7 +69,7 @@ def get_params() -> AttributeDict:
         {
             "env_info": get_env_info(),
             "feature_dim": 80,
-            "number_class": 9,
+            "num_class": 9,
         }
     )
     return params
@@ -150,7 +150,7 @@ def main():
 
     logging.info(f"device: {device}")
 
-    model = Tdnn(params.feature_dim, params.number_class)
+    model = Tdnn(params.feature_dim, params.num_class)
 
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True)
diff --git a/egs/himia/wuw/ctc_tdnn/tdnn.py b/egs/himia/wuw/ctc_tdnn/tdnn.py
index 0f685b6c2..3425d4cca 100644
--- a/egs/himia/wuw/ctc_tdnn/tdnn.py
+++ b/egs/himia/wuw/ctc_tdnn/tdnn.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py
index e019ebb86..5bd54d2f0 100644
--- a/egs/himia/wuw/ctc_tdnn/tokenizer.py
+++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py
@@ -1,4 +1,4 @@
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py
index 61842de79..04953d9c3 100755
--- a/egs/himia/wuw/ctc_tdnn/train.py
+++ b/egs/himia/wuw/ctc_tdnn/train.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -162,7 +162,7 @@ def get_params() -> AttributeDict:
 
     - feature_dim: The model input dim. It has to match the one used
       in computing features.
 
-    - number_class: Numer of classes. Each token will have a token id
+    - num_class: Number of classes. Each token will have a token id
      from [0, num_class). In this recipe, 0 is usually kept for blank,
      and 1 is usually kept for negative words.
@@ -182,7 +182,7 @@ def get_params() -> AttributeDict:
         "valid_interval": 3000,
         # parameters for model
         "feature_dim": 80,
-        "number_class": 9,
+        "num_class": 9,
         # parameters for tokenizer
         "wakeup_word": "你好米雅",
         "wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8],
@@ -529,7 +529,7 @@ def run(rank, world_size, args):
 
     logging.info("About to create model")
 
-    model = Tdnn(params.feature_dim, params.number_class)
+    model = Tdnn(params.feature_dim, params.num_class)
 
     checkpoints = load_checkpoint_if_available(params=params, model=model)
 
diff --git a/egs/himia/wuw/local/auc.py b/egs/himia/wuw/local/auc.py
index 7b35ef06b..f5a210d87 100755
--- a/egs/himia/wuw/local/auc.py
+++ b/egs/himia/wuw/local/auc.py
@@ -55,7 +55,7 @@ def load_score(score_file: Path) -> Dict[str, float]:
     """
     Args:
       score_file: Path to score file. Each line has two columns.
-        The first colume is utt-id, and the second one is score.
+        The first column is utt-id, and the second one is score.
         This score could be viewed as probability of being wakeup word.
 
     Returns:
@@ -81,9 +81,9 @@ def get_roc_and_auc(
       pos_dict: scores of positive samples.
       neg_dict: scores of negative samples.
     Return:
-      A tuple of three elements, which will be used to plot roc curve.
+      A tuple of three elements, which will be used to plot the ROC curve.
       Refer to sklearn.metrics.roc_curve for meaning of the first and second elements.
-      The third element is area under the roc curve(AUC).
+      The third element is the area under the ROC curve (AUC).
     """
     pos_scores = np.fromiter(pos_dict.values(), dtype=float)
     neg_scores = np.fromiter(neg_dict.values(), dtype=float)
diff --git a/egs/himia/wuw/local/compute_fbank_himia.py b/egs/himia/wuw/local/compute_fbank_himia.py
index f930a8c4e..3acac8b0f 100755
--- a/egs/himia/wuw/local/compute_fbank_himia.py
+++ b/egs/himia/wuw/local/compute_fbank_himia.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -57,7 +57,7 @@ def get_args():
         "--enable-speed-perturb",
         type=str2bool,
         default=False,
-        help="""channel of trianing set.
+        help="""Whether to apply speed perturbation to the training set.
         """,
     )
     return parser.parse_args()
diff --git a/egs/himia/wuw/prepare.sh b/egs/himia/wuw/prepare.sh
index a47a20682..96df29097 100755
--- a/egs/himia/wuw/prepare.sh
+++ b/egs/himia/wuw/prepare.sh
@@ -8,7 +8,7 @@ stop_stage=6
 # HI_MIA and aishell dataset are used in this experiment.
 # musan dataset is used for data augmentation.
 #
-# For aishell dataset downlading and preparation,
+# For aishell dataset downloading and preparation,
 # refer to icefall/egs/aishell/ASR/prepare.sh.
 #
 # For HI_MIA and HI_MIA_CW dataset,
@@ -96,7 +96,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare HI_MIA and HI_MIA_CWmanifest"
+  log "Stage 1: Prepare HI_MIA and HI_MIA_CW manifest"
   mkdir -p data/manifests
   if [ ! -e data/manifests/.himia.done ]; then
     lhotse prepare himia $dl_dir/HiMia data/manifests
@@ -177,8 +177,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
 
   train_file=data/fbank/cuts_train_himia${train_set_channel}-aishell-shuf.jsonl.gz
   if [ ! -f ${train_file} ]; then
-    # SingleCutSampler is prefered for this experiment.
-    # So `shuf` the training dataset here.
+    # SingleCutSampler is preferred over DynamicBucketingSampler for
+    # this experiment, since negative audios (Aishell) tend to be longer
+    # than positive ones (HiMia).
+    # With DynamicBucketingSampler, a batch might contain only negative
+    # samples or only positive samples.
+    # So `shuf` the training dataset here and use SingleCutSampler to load data.
     cat <(gunzip -c data/fbank/aishell_cuts_train.jsonl.gz) \
       <(gunzip -c data/fbank/cuts_train${train_set_channel}.jsonl.gz) | \
       grep -v _sp | \
diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh
index 8a65d9c54..258c2b2b1 100644
--- a/egs/himia/wuw/run_ctc_tdnn.sh
+++ b/egs/himia/wuw/run_ctc_tdnn.sh
@@ -26,6 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Model training"
   python ./ctc_tdnn/train.py \
     --num-epochs $epoch \
+    --exp-dir $exp_dir \
     --max-duration $max_duration
 fi
 
@@ -34,7 +35,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   python ctc_tdnn/inference.py \
     --avg $avg \
     --epoch $epoch \
-    --exp-dir ${exp_dir}
+    --exp-dir $exp_dir
 fi
 
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -45,12 +46,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
       --score-file ${post_dir}/fst_${test_set}_score.txt
   done
   python ./local/auc.py \
-      --legend himia_cw \
-      --positive-score-file ${post_dir}/fst_test_score.txt \
-      --negative-score-file ${post_dir}/fst_cw_test_score.txt
+    --legend himia_cw \
+    --positive-score-file ${post_dir}/fst_test_score.txt \
+    --negative-score-file ${post_dir}/fst_cw_test_score.txt
   python ./local/auc.py \
-      --legend himia_aishell \
-      --positive-score-file ${post_dir}/fst_test_score.txt \
-      --negative-score-file ${post_dir}/fst_aishell_test_score.txt
+    --legend himia_aishell \
+    --positive-score-file ${post_dir}/fst_test_score.txt \
+    --negative-score-file ${post_dir}/fst_aishell_test_score.txt
 fi
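
Illustration (not part of the patch): a minimal sketch of the string format
accepted by the FiniteStateTransducer constructor changed above. The token
ids and arcs below are made up; a real graph for "你好米雅" would come from
ctc_trivial_decoding_graph() in ctc_tdnn/graph.py. It assumes the script is
run from egs/himia/wuw so that ctc_tdnn/decode.py is importable.

    from ctc_tdnn.decode import FiniteStateTransducer

    # Each non-empty line is either an arc with 4 fields
    # "src_state dst_state ilabel olabel", or the final state (1 or 2
    # fields). Lines must be sorted by src_state, and the final state
    # must already exist (here via its self-loop) when it is declared.
    graph_str = (
        "0 0 1 0\n"  # stay in state 0 on non-wake-word tokens
        "0 1 2 2\n"  # first wake-word token id (made up)
        "1 2 3 3\n"  # second wake-word token id (made up)
        "2 2 0 0\n"  # blank self-loop on the last state
        "2\n"        # state 2 is the final state
    )

    fst = FiniteStateTransducer(graph_str)
    print(fst.to_str())  # prints the arcs back in text form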
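Similarly, get_roc_and_auc() in local/auc.py boils down to pooling the
positive and negative scores and handing them to sklearn. A self-contained
sketch with invented utterance ids and scores, using only documented
sklearn.metrics calls:

    import numpy as np
    from sklearn.metrics import auc, roc_curve

    # Toy score dicts in the format produced by load_score():
    # utt-id -> probability of being the wakeup word.
    pos_dict = {"pos_utt_1": 0.94, "pos_utt_2": 0.71}
    neg_dict = {"neg_utt_1": 0.12, "neg_utt_2": 0.33}

    pos_scores = np.fromiter(pos_dict.values(), dtype=float)
    neg_scores = np.fromiter(neg_dict.values(), dtype=float)

    # Label positives 1 and negatives 0, then pool the scores.
    y_true = np.concatenate([np.ones_like(pos_scores), np.zeros_like(neg_scores)])
    y_score = np.concatenate([pos_scores, neg_scores])

    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    print("AUC:", auc(fpr, tpr))  # area under the ROC curve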