update comments

glynpu 2023-03-16 20:03:57 +08:00
parent 2230669129
commit e64a6e7bec
12 changed files with 67 additions and 41 deletions

egs/himia/wuw/README.md Normal file

@@ -0,0 +1,10 @@
+# Pretrained models and related logs/results.
+## CTC TDNN baseline
+AUC results for different epochs can be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline>
+E.g., for epoch 15 and avg 1, the result log file is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/log/log-auc-himia_aishell-2023-03-16-16-02-56>
+The corresponding ROC curve is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/himia_aishell.pdf>


@@ -1,10 +0,0 @@
-# Pretrained models and releated logs/results.
-## ctc tdnn baseline
-Auc results for different epochs could be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline>
-E.g. for epoch 2 and avg 1, auc log file is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_2-avg_1/log/log-auc-himia_aishell-2023-03-16-16-02-56>
-Corresponding ROC curve is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_2-avg_1/himia_aishell.pdf>


@@ -61,28 +61,49 @@ class FiniteStateTransducer:
     """Represents a decoding graph for wake word detection."""

     def __init__(self, graph: str) -> None:
+        """
+        Construct a decoding graph in FST format from a string format graph.
+
+        Args:
+          graph: An FST in string format. Arcs are separated by "\n".
+        """
         self.state_list = list()
         for arc_str in graph.split("\n"):
             arc = arc_str.strip().split()
             if len(arc) == 0:
                 continue
+            # An arc may contain 1, 2 or 4 elements, with format:
+            #   src_state [dst_state] [ilabel] [olabel]
             # 1 and 2 for final state
             # 4 for non-final state
             assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}"
+            arc = [int(element) for element in arc]
+            src_state_id = arc[0]
+            max_state_id = len(self.state_list) - 1
             if len(arc) == 4:  # Non-final state
-                # FST must be sorted
-                if len(self.state_list) <= int(arc[0]):
+                assert max_state_id <= src_state_id, (
+                    f"FST must be sorted by src_state, but got src_state "
+                    f"{src_state_id} after state {max_state_id}. Check your graph."
+                )
+                if max_state_id < src_state_id:
                     new_state = State()
                     self.state_list.append(new_state)
-                self.state_list[int(arc[0])].add_arc(
-                    Arc(arc[0], arc[1], arc[2], arc[3])
+                self.state_list[src_state_id].add_arc(
+                    Arc(src_state_id, arc[1], arc[2], arc[3])
                 )
             else:
-                self.final_state_id = int(arc[0])
+                assert (
+                    max_state_id == src_state_id
+                ), "Final state seems unreachable. Check your graph."
+                self.final_state_id = src_state_id

     def to_str(self) -> str:
         fst_str = ""
-        for state_idx in range(len(self.state_list)):
+        number_states = len(self.state_list)
+        if number_states == 0:
+            return fst_str
+        for state_idx in range(number_states):
             cur_state = self.state_list[state_idx]
             for arc_idx in range(len(cur_state.arc_list)):
                 cur_arc = cur_state.arc_list[arc_idx]
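
For reference, the arc format parsed above can be exercised with a tiny graph. This is a sketch, not part of the diff: it assumes FiniteStateTransducer (with State and Arc) is importable from this file, and the graph itself is illustrative rather than taken from the recipe.

# Each arc is "src_state dst_state ilabel olabel"; a line with only a
# state id marks the final state. Arcs must be sorted by src_state.
graph = (
    "0 0 0 0\n"  # self-loop on blank (token 0) in state 0
    "0 1 2 2\n"  # consume token 2 and move to state 1
    "1 1 2 0\n"  # CTC-style self-loop for repeats of token 2
    "1"          # state 1 is final
)
fst = FiniteStateTransducer(graph)
print(fst.to_str())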


@@ -21,7 +21,7 @@ from typing import List
 def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
     """
-    A graph starts with blank/unknown and follwoing by wakeup word.
+    A graph that starts with blank/unknown, followed by the wakeup word.

     Args:
       wakeup_word_tokens: A sequence of token ids corresponding to the wakeup word.
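
To make the docstring concrete: for wakeup_word_tokens = [2, 3], a trivial CTC decoding graph has roughly the shape sketched below, written in the FST string format from the previous file. This is a hedged illustration; the exact arcs emitted by ctc_trivial_decoding_graph may differ.

graph_sketch = (
    "0 0 0 0\n"  # stay on blank before the wakeup word
    "0 0 1 1\n"  # stay on unknown/filler tokens
    "0 1 2 2\n"  # first wakeup token
    "1 1 2 0\n"  # repeats of token 2 collapse under CTC
    "1 2 3 3\n"  # second wakeup token
    "2 2 3 0\n"  # repeats of token 3
    "2"          # final state
)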


@@ -69,7 +69,7 @@ def get_params() -> AttributeDict:
         {
             "env_info": get_env_info(),
             "feature_dim": 80,
-            "number_class": 9,
+            "num_class": 9,
         }
     )
     return params
@@ -150,7 +150,7 @@ def main():
     logging.info(f"device: {device}")

-    model = Tdnn(params.feature_dim, params.number_class)
+    model = Tdnn(params.feature_dim, params.num_class)

     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True)
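
The params.avg > 1 branch falls outside this hunk. For context, icefall recipes usually average the last avg checkpoints; a sketch under that assumption (this recipe's exact handling may differ):

from icefall.checkpoint import average_checkpoints

if params.avg > 1:
    # Average the state dicts of epochs [epoch - avg + 1, epoch].
    start = params.epoch - params.avg + 1
    filenames = [
        f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)
    ]
    model.load_state_dict(average_checkpoints(filenames))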


@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #


@@ -1,4 +1,4 @@
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #


@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -162,7 +162,7 @@ def get_params() -> AttributeDict:
         - feature_dim: The model input dim. It has to match the one used
                        in computing features.
-        - number_class: Numer of classes. Each token will have a token id
+        - num_class: Number of classes. Each token will have a token id
                        from [0, num_class).
                        In this recipe, 0 is usually kept for blank,
                        and 1 is usually kept for negative words.
@@ -182,7 +182,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 3000,
             # parameters for model
             "feature_dim": 80,
-            "number_class": 9,
+            "num_class": 9,
             # parameters for tokenizer
             "wakeup_word": "你好米雅",
             "wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8],
@@ -529,7 +529,7 @@ def run(rank, world_size, args):
     logging.info("About to create model")
-    model = Tdnn(params.feature_dim, params.number_class)
+    model = Tdnn(params.feature_dim, params.num_class)

     checkpoints = load_checkpoint_if_available(params=params, model=model)
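
A note on the wakeup_word_tokens parameter above: [2, 3, 4, 5, 6, 3, 7, 8] is consistent with a pinyin initial/final tokenization of "你好米雅" (ni hao mi ya). The mapping below is inferred for illustration, not taken from the recipe's tokenizer:

# Hypothetical ids, consistent with the params above: "i" (id 3) occurs
# twice, matching the repeated 3 in wakeup_word_tokens.
token2id = {"n": 2, "i": 3, "h": 4, "ao": 5, "m": 6, "y": 7, "a": 8}
tokens = [token2id[t] for t in "n i h ao m i y a".split()]
assert tokens == [2, 3, 4, 5, 6, 3, 7, 8]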


@@ -55,7 +55,7 @@ def load_score(score_file: Path) -> Dict[str, float]:
     """
     Args:
       score_file: Path to score file. Each line has two columns.
-        The first colume is utt-id, and the second one is score.
+        The first column is utt-id, and the second one is score.
         This score can be viewed as the probability of being the wakeup word.

     Returns:
@@ -81,9 +81,9 @@ def get_roc_and_auc(
       pos_dict: scores of positive samples.
       neg_dict: scores of negative samples.

     Return:
-      A tuple of three elements, which will be used to plot roc curve.
+      A tuple of three elements, which will be used to plot the ROC curve.
       Refer to sklearn.metrics.roc_curve for the meaning of the first and second elements.
-      The third element is area under the roc curve(AUC).
+      The third element is the area under the ROC curve (AUC).
     """
     pos_scores = np.fromiter(pos_dict.values(), dtype=float)
     neg_scores = np.fromiter(neg_dict.values(), dtype=float)
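
The rest of get_roc_and_auc lies outside this hunk. A self-contained sketch of the computation the docstring describes, using sklearn.metrics as referenced above (everything beyond the two np.fromiter lines shown is illustrative):

import numpy as np
from sklearn.metrics import auc, roc_curve

def get_roc_and_auc_sketch(pos_dict, neg_dict):
    pos_scores = np.fromiter(pos_dict.values(), dtype=float)
    neg_scores = np.fromiter(neg_dict.values(), dtype=float)
    # Label positive samples 1 and negative samples 0, then score them all.
    y_score = np.concatenate([pos_scores, neg_scores])
    y_true = np.concatenate([np.ones_like(pos_scores), np.zeros_like(neg_scores)])
    fpr, tpr, _ = roc_curve(y_true, y_score)
    return fpr, tpr, auc(fpr, tpr)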


@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -57,7 +57,7 @@ def get_args():
         "--enable-speed-perturb",
         type=str2bool,
         default=False,
-        help="""channel of trianing set.
+        help="""Whether to apply speed perturbation to the training set.
         """,
     )
     return parser.parse_args()


@@ -8,7 +8,7 @@ stop_stage=6
 # HI_MIA and aishell datasets are used in this experiment.
 # musan dataset is used for data augmentation.
 #
-# For aishell dataset downlading and preparation,
+# For aishell dataset downloading and preparation,
 # refer to icefall/egs/aishell/ASR/prepare.sh.
 #
 # For HI_MIA and HI_MIA_CW datasets,
@@ -177,8 +177,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   train_file=data/fbank/cuts_train_himia${train_set_channel}-aishell-shuf.jsonl.gz
   if [ ! -f ${train_file} ]; then
-    # SingleCutSampler is prefered for this experiment.
-    # So `shuf` the training dataset here.
+    # SingleCutSampler is preferred for this experiment
+    # rather than DynamicBucketingSampler:
+    # negative audio (aishell) tends to be longer than positive audio (HI_MIA),
+    # so with DynamicBucketingSampler a batch could contain all negative
+    # or all positive samples.
+    # Hence `shuf` the training dataset here and use SingleCutSampler to load it.
     cat <(gunzip -c data/fbank/aishell_cuts_train.jsonl.gz) \
       <(gunzip -c data/fbank/cuts_train${train_set_channel}.jsonl.gz) | \
       grep -v _sp | \
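
For context on the comment above, the shuffled cut set is then consumed with lhotse's SingleCutSampler on the Python side. A sketch: max_duration=100 matches the exp dir name in the README, the channel suffix is omitted from the file name, and the rest is illustrative.

from lhotse import CutSet
from lhotse.dataset import SingleCutSampler

# The jsonl.gz was pre-shuffled by `shuf` above, so the sampler can read it
# sequentially and still mix long aishell cuts with short HI_MIA cuts.
cuts = CutSet.from_file("data/fbank/cuts_train_himia-aishell-shuf.jsonl.gz")
sampler = SingleCutSampler(cuts, max_duration=100, shuffle=False)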


@@ -26,6 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Model training"
   python ./ctc_tdnn/train.py \
     --num-epochs $epoch \
+    --exp-dir $exp_dir \
     --max-duration $max_duration
 fi
@@ -34,7 +35,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   python ctc_tdnn/inference.py \
     --avg $avg \
     --epoch $epoch \
-    --exp-dir ${exp_dir}
+    --exp-dir $exp_dir
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then