update comments

glynpu 2023-03-16 20:03:57 +08:00
parent 2230669129
commit e64a6e7bec
12 changed files with 67 additions and 41 deletions

egs/himia/wuw/README.md (new file)

@@ -0,0 +1,10 @@
+# Pretrained models and related logs/results.
+## ctc tdnn baseline
+AUC results for different epochs can be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline>
+E.g., for epoch 15 and avg 1, the result log file is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/log/log-auc-himia_aishell-2023-03-16-16-02-56>
+The corresponding ROC curve is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/himia_aishell.pdf>


@ -1,10 +0,0 @@
# Pretrained models and releated logs/results.
## ctc tdnn baseline
Auc results for different epochs could be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline>
E.g. for epoch 2 and avg 1, auc log file is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_2-avg_1/log/log-auc-himia_aishell-2023-03-16-16-02-56>
Corresponding ROC curve is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_2-avg_1/himia_aishell.pdf>


@@ -61,28 +61,49 @@ class FiniteStateTransducer:
    """Represents a decoding graph for wake word detection."""

    def __init__(self, graph: str) -> None:
        """
        Construct a decoding graph in FST format given a string format graph.

        Args:
          graph: A string format FST. Each arc is separated by "\n".
        """
        self.state_list = list()
        for arc_str in graph.split("\n"):
            arc = arc_str.strip().split()
            if len(arc) == 0:
                continue
            # An arc may contain 1, 2 or 4 elements, with format:
            # src_state [dst_state] [ilabel] [olabel]
            # 1 and 2 for a final state,
            # 4 for a non-final state.
            assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}"
+            arc = [int(element) for element in arc]
+            src_state_id = arc[0]
+            max_state_id = len(self.state_list) - 1
            if len(arc) == 4:  # Non-final state
-                # FST must be sorted
-                if len(self.state_list) <= int(arc[0]):
+                assert max_state_id <= src_state_id, (
+                    "FST must be sorted by src_state, "
+                    f"but {max_state_id} > {src_state_id}. Check your graph."
+                )
+                if max_state_id < src_state_id:
                    new_state = State()
                    self.state_list.append(new_state)
-                self.state_list[int(arc[0])].add_arc(
-                    Arc(arc[0], arc[1], arc[2], arc[3])
+                self.state_list[src_state_id].add_arc(
+                    Arc(src_state_id, arc[1], arc[2], arc[3])
                )
            else:
-                self.final_state_id = int(arc[0])
+                assert (
+                    max_state_id == src_state_id
+                ), "Final state seems unreachable. Check your graph."
+                self.final_state_id = src_state_id

    def to_str(self) -> str:
        fst_str = ""
-        for state_idx in range(len(self.state_list)):
+        number_states = len(self.state_list)
+        if number_states == 0:
+            return fst_str
+        for state_idx in range(number_states):
            cur_state = self.state_list[state_idx]
            for arc_idx in range(len(cur_state.arc_list)):
                cur_arc = cur_state.arc_list[arc_idx]
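For reference, a minimal sketch of the text format this constructor parses (the two-state graph below is a made-up illustration, not one from the recipe): each non-final arc is a line `src_state dst_state ilabel olabel`, and the final state sits on a line of its own.

```python
# Hypothetical graph: state 0 self-loops on label 1, advances to state 1
# on label 2, and state 1 is final. Assumes the State/Arc classes above.
graph = """0 0 1 0
0 1 2 2
1 1 1 0
1"""

fst = FiniteStateTransducer(graph)
print(fst.to_str())  # echoes the arcs back in the same text format
```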


@@ -21,7 +21,7 @@ from typing import List

def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
    """
-    A graph starts with blank/unknown and follwoing by wakeup word.
+    A graph that starts with blank/unknown, followed by the wakeup word.

    Args:
      wakeup_word_tokens: A sequence of token ids corresponding to the wakeup_word.
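Since the hunk only shows the docstring, here is a hedged sketch of what such a trivial CTC decoding graph could look like in the text format above; it illustrates the idea under this recipe's conventions (token 0 = blank, 1 = unknown/negative) and is not the function's actual body.

```python
from typing import List

def ctc_trivial_decoding_graph_sketch(wakeup_word_tokens: List[int]) -> str:
    """Left-to-right graph: loop on blank/unknown at the start state,
    then advance one state per wake-word token, looping on blank between."""
    arcs = ["0 0 0 0", "0 0 1 0"]  # absorb blank (0) and unknown (1)
    for i, token in enumerate(wakeup_word_tokens):
        arcs.append(f"{i} {i + 1} {token} {token}")  # consume the next token
        arcs.append(f"{i + 1} {i + 1} 0 0")  # blank self-loop afterwards
    arcs.append(str(len(wakeup_word_tokens)))  # final state on its own line
    return "\n".join(arcs)
```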


@@ -69,7 +69,7 @@ def get_params() -> AttributeDict:
        {
            "env_info": get_env_info(),
            "feature_dim": 80,
-            "number_class": 9,
+            "num_class": 9,
        }
    )
    return params
@@ -150,7 +150,7 @@ def main():
    logging.info(f"device: {device}")

-    model = Tdnn(params.feature_dim, params.number_class)
+    model = Tdnn(params.feature_dim, params.num_class)

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True)
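For `avg > 1`, icefall recipes usually average the parameters of the last `avg` checkpoints before decoding; a hedged sketch, assuming `average_checkpoints` from `icefall.checkpoint` behaves as in the other recipes:

```python
from icefall.checkpoint import average_checkpoints

# Average epochs [epoch - avg + 1, epoch]; assumes those .pt files exist.
start = params.epoch - params.avg + 1
filenames = [f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)]
model.load_state_dict(average_checkpoints(filenames, device=device))
```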


@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#


@@ -1,4 +1,4 @@
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
#
# See ../../LICENSE for clarification regarding multiple authors
#


@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -162,7 +162,7 @@ def get_params() -> AttributeDict:
    - feature_dim: The model input dim. It has to match the one used
                   in computing features.

-    - number_class: Numer of classes. Each token will have a token id
+    - num_class: Number of classes. Each token will have a token id
      from [0, num_class).
      In this recipe, 0 is usually kept for blank,
      and 1 is usually kept for negative words.
@@ -182,7 +182,7 @@ def get_params() -> AttributeDict:
            "valid_interval": 3000,
            # parameters for model
            "feature_dim": 80,
-            "number_class": 9,
+            "num_class": 9,
            # parameters for tokenizer
            "wakeup_word": "你好米雅",
            "wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8],
@@ -529,7 +529,7 @@ def run(rank, world_size, args):
    logging.info("About to create model")
-    model = Tdnn(params.feature_dim, params.number_class)
+    model = Tdnn(params.feature_dim, params.num_class)

    checkpoints = load_checkpoint_if_available(params=params, model=model)


@@ -55,7 +55,7 @@ def load_score(score_file: Path) -> Dict[str, float]:
    """
    Args:
      score_file: Path to score file. Each line has two columns.
-        The first colume is utt-id, and the second one is score.
+        The first column is utt-id, and the second one is score.
        This score could be viewed as the probability of being the wakeup word.

    Returns:
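The docstring pins down the file format well enough for a minimal parsing sketch (whitespace-separated columns are assumed):

```python
from pathlib import Path
from typing import Dict

def load_score_sketch(score_file: Path) -> Dict[str, float]:
    """Parse "utt-id score" lines into a dict, as the docstring describes."""
    scores: Dict[str, float] = {}
    with open(score_file) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue  # skip blank lines
            utt_id, score = parts
            scores[utt_id] = float(score)
    return scores
```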
@@ -81,9 +81,9 @@ def get_roc_and_auc(
      pos_dict: scores of positive samples.
      neg_dict: scores of negative samples.

    Return:
-      A tuple of three elements, which will be used to plot roc curve.
+      A tuple of three elements, which will be used to plot the ROC curve.
      Refer to sklearn.metrics.roc_curve for the meaning of the first and second elements.
-      The third element is area under the roc curve(AUC).
+      The third element is the area under the ROC curve (AUC).
    """
    pos_scores = np.fromiter(pos_dict.values(), dtype=float)
    neg_scores = np.fromiter(neg_dict.values(), dtype=float)
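The contract above can be read as the following sketch, assuming `sklearn` is available and positives are labeled 1, negatives 0:

```python
import numpy as np
from sklearn.metrics import auc, roc_curve

def get_roc_and_auc_sketch(pos_dict, neg_dict):
    pos_scores = np.fromiter(pos_dict.values(), dtype=float)
    neg_scores = np.fromiter(neg_dict.values(), dtype=float)
    y_true = np.concatenate([np.ones_like(pos_scores), np.zeros_like(neg_scores)])
    y_score = np.concatenate([pos_scores, neg_scores])
    fpr, tpr, _ = roc_curve(y_true, y_score)  # first two returned elements
    return fpr, tpr, auc(fpr, tpr)  # third element: area under the ROC curve
```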


@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -57,7 +57,7 @@ def get_args():
        "--enable-speed-perturb",
        type=str2bool,
        default=False,
-        help="""channel of trianing set.
+        help="""Whether to apply speed perturbation to the training set.
        """,
    )

    return parser.parse_args()


@@ -8,7 +8,7 @@ stop_stage=6
# The HI_MIA and aishell datasets are used in this experiment.
# The musan dataset is used for data augmentation.
#
-# For aishell dataset downlading and preparation,
+# For aishell dataset downloading and preparation,
# refer to icefall/egs/aishell/ASR/prepare.sh.
#
# For the HI_MIA and HI_MIA_CW datasets,
@@ -96,7 +96,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare HI_MIA and HI_MIA_CWmanifest"
+  log "Stage 1: Prepare HI_MIA and HI_MIA_CW manifest"
  mkdir -p data/manifests
  if [ ! -e data/manifests/.himia.done ]; then
    lhotse prepare himia $dl_dir/HiMia data/manifests
@@ -177,8 +177,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  train_file=data/fbank/cuts_train_himia${train_set_channel}-aishell-shuf.jsonl.gz
  if [ ! -f ${train_file} ]; then
-    # SingleCutSampler is prefered for this experiment.
-    # So `shuf` the training dataset here.
+    # SingleCutSampler is preferred for this experiment
+    # rather than DynamicBucketingSampler,
+    # since negative audios (Aishell) tend to be longer than positive ones (HiMia).
+    # If DynamicBucketingSampler were used, a batch might contain either
+    # all negative samples or all positive samples.
+    # So `shuf` the training dataset here and use SingleCutSampler to load data.
    cat <(gunzip -c data/fbank/aishell_cuts_train.jsonl.gz) \
      <(gunzip -c data/fbank/cuts_train${train_set_channel}.jsonl.gz) | \
      grep -v _sp | \
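
The comment's rationale in code form: a hedged sketch of loading the pre-shuffled manifest with lhotse's `SingleCutSampler` (the `max_duration` mirrors this experiment; the channel suffix in the path is elided, and the sampler API is assumed from lhotse):

```python
from lhotse import CutSet
from lhotse.dataset import SingleCutSampler

# SingleCutSampler draws cuts in manifest order, so shuffling the manifest
# on disk (the `shuf` above) is what interleaves positive (HiMia) and
# negative (Aishell) cuts within each batch.
cuts = CutSet.from_file("data/fbank/cuts_train_himia-aishell-shuf.jsonl.gz")
sampler = SingleCutSampler(cuts, max_duration=100)
```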


@@ -26,6 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Model training"
  python ./ctc_tdnn/train.py \
    --num-epochs $epoch \
+    --exp-dir $exp_dir \
    --max-duration $max_duration
fi
@@ -34,7 +35,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  python ctc_tdnn/inference.py \
    --avg $avg \
    --epoch $epoch \
-    --exp-dir ${exp_dir}
+    --exp-dir $exp_dir
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -45,12 +46,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
      --score-file ${post_dir}/fst_${test_set}_score.txt
  done

  python ./local/auc.py \
-      --legend himia_cw \
-      --positive-score-file ${post_dir}/fst_test_score.txt \
-      --negative-score-file ${post_dir}/fst_cw_test_score.txt
+    --legend himia_cw \
+    --positive-score-file ${post_dir}/fst_test_score.txt \
+    --negative-score-file ${post_dir}/fst_cw_test_score.txt

  python ./local/auc.py \
-      --legend himia_aishell \
-      --positive-score-file ${post_dir}/fst_test_score.txt \
-      --negative-score-file ${post_dir}/fst_aishell_test_score.txt
+    --legend himia_aishell \
+    --positive-score-file ${post_dir}/fst_test_score.txt \
+    --negative-score-file ${post_dir}/fst_aishell_test_score.txt
fi