Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-11 19:12:30 +00:00)
commit e64a6e7bec
parent 2230669129

    update comments
egs/himia/wuw/README.md (new file, 10 lines)

@@ -0,0 +1,10 @@
+# Pretrained models and related logs/results.
+
+## ctc tdnn baseline
+
+AUC results for different epochs could be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline>
+
+E.g. for epoch 15 and avg 1, result log file is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/log/log-auc-himia_aishell-2023-03-16-16-02-56>
+
+Corresponding ROC curve is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/himia_aishell.pdf>
+
@@ -1,10 +0,0 @@
-# Pretrained models and releated logs/results.
-
-## ctc tdnn baseline
-
-Auc results for different epochs could be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline>
-
-E.g. for epoch 2 and avg 1, auc log file is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_2-avg_1/log/log-auc-himia_aishell-2023-03-16-16-02-56>
-
-Corresponding ROC curve is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_2-avg_1/himia_aishell.pdf>
-
@@ -61,28 +61,49 @@ class FiniteStateTransducer:
     """Represents a decoding graph for wake word detection."""

     def __init__(self, graph: str) -> None:
+        """
+        Construct a decoding graph in FST format given string format graph.
+
+        Args:
+          graph: A string format fst. Each arc is separated by "\n".
+        """
         self.state_list = list()
         for arc_str in graph.split("\n"):
             arc = arc_str.strip().split()
             if len(arc) == 0:
                 continue
+            # An arc may contain 1, 2 or 4 elements, with format:
+            # src_state [dst_state] [ilabel] [olabel]
             # 1 and 2 for final state
             # 4 for non-final state
             assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}"
+            arc = [int(element) for element in arc]
+            src_state_id = arc[0]
+            max_state_id = len(self.state_list) - 1
             if len(arc) == 4:  # Non-final state
-                # FST must be sorted
-                if len(self.state_list) <= int(arc[0]):
+                assert max_state_id <= src_state_id, (
+                    f"Fsa must be sorted by src_state, "
+                    f"while {max_state_id} > {src_state_id}. Check your graph."
+                )
+                if max_state_id < src_state_id:
                     new_state = State()
                     self.state_list.append(new_state)
-                self.state_list[int(arc[0])].add_arc(
-                    Arc(arc[0], arc[1], arc[2], arc[3])
+                self.state_list[src_state_id].add_arc(
+                    Arc(src_state_id, arc[1], arc[2], arc[3])
                 )
             else:
-                self.final_state_id = int(arc[0])
+                assert (
+                    max_state_id == src_state_id
+                ), f"Final state seems unreachable. Check your graph."
+                self.final_state_id = src_state_id

     def to_str(self) -> None:
         fst_str = ""
-        for state_idx in range(len(self.state_list)):
+        number_states = len(self.state_list)
+        if number_states == 0:
+            return fst_str
+        for state_idx in range(number_states):
             cur_state = self.state_list[state_idx]
             for arc_idx in range(len(cur_state.arc_list)):
                 cur_arc = cur_state.arc_list[arc_idx]
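As an aside for readers of this hunk, here is a minimal, self-contained sketch of the string format the constructor parses. The `State`/`Arc` stand-ins below are hypothetical simplifications of the recipe's own classes, and the loop mirrors the parsing above without the sortedness checks:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Arc:
    # Hypothetical stand-in for the recipe's Arc container.
    src_state: int
    dst_state: int
    ilabel: int
    olabel: int


@dataclass
class State:
    # Hypothetical stand-in for the recipe's State container.
    arc_list: List[Arc] = field(default_factory=list)

    def add_arc(self, arc: Arc) -> None:
        self.arc_list.append(arc)


# One arc per line: "src_state dst_state ilabel olabel";
# the last line holds only the final state id.
graph = """
0 0 1 0
0 1 2 2
1 1 2 2
1
"""

state_list: List[State] = []
final_state_id: Optional[int] = None
for arc_str in graph.split("\n"):
    arc = arc_str.strip().split()
    if len(arc) == 0:
        continue
    arc = [int(element) for element in arc]
    if len(arc) == 4:  # non-final state: create its State on first sight
        if len(state_list) - 1 < arc[0]:
            state_list.append(State())
        state_list[arc[0]].add_arc(Arc(arc[0], arc[1], arc[2], arc[3]))
    else:  # final state line carries only the state id
        final_state_id = arc[0]

print(final_state_id)  # -> 1
```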
@@ -21,7 +21,7 @@ from typing import List

 def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
     """
-    A graph starts with blank/unknown and follwoing by wakeup word.
+    A graph starts with blank/unknown, followed by the wakeup word.

     Args:
       wakeup_word_tokens: A sequence of token ids corresponding wakeup_word.
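A sketch of how this helper is presumably wired up, assuming (as the constructor docstring above suggests) that it returns the graph as a string which `FiniteStateTransducer` then parses; the token ids are the recipe's own `wakeup_word_tokens`:

```python
# Hypothetical usage; token ids follow the recipe's convention:
# 0 = blank, 1 = negative/unknown, 2..8 = wakeup-word tokens.
wakeup_word_tokens = [2, 3, 4, 5, 6, 3, 7, 8]  # "你好米雅"
graph_str = ctc_trivial_decoding_graph(wakeup_word_tokens)
decoding_graph = FiniteStateTransducer(graph_str)
```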
@@ -69,7 +69,7 @@ def get_params() -> AttributeDict:
         {
             "env_info": get_env_info(),
             "feature_dim": 80,
-            "number_class": 9,
+            "num_class": 9,
         }
     )
     return params
@@ -150,7 +150,7 @@ def main():

     logging.info(f"device: {device}")

-    model = Tdnn(params.feature_dim, params.number_class)
+    model = Tdnn(params.feature_dim, params.num_class)

     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True)
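Only the `avg == 1` branch is visible in this hunk. For `params.avg > 1`, icefall recipes typically average the trailing checkpoints instead; a hedged sketch following that common pattern (`average_checkpoints` is icefall's helper, and the `start` arithmetic here is illustrative, not copied from this file):

```python
from icefall.checkpoint import average_checkpoints

# Hypothetical avg > 1 branch: average the last `avg` epoch checkpoints.
if params.avg > 1:
    start = params.epoch - params.avg + 1
    filenames = [
        f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)
    ]
    model.load_state_dict(average_checkpoints(filenames, device=device))
```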
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -1,4 +1,4 @@
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../LICENSE for clarification regarding multiple authors
 #
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -162,7 +162,7 @@ def get_params() -> AttributeDict:
        - feature_dim: The model input dim. It has to match the one used
                       in computing features.

-       - number_class: Numer of classes. Each token will have a token id
+       - num_class: Number of classes. Each token will have a token id
                     from [0, num_class).
                     In this recipe, 0 is usually kept for blank,
                     and 1 is usually kept for negative words.
@@ -182,7 +182,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 3000,
             # parameters for model
             "feature_dim": 80,
-            "number_class": 9,
+            "num_class": 9,
             # parameters for tokenizer
             "wakeup_word": "你好米雅",
             "wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8],
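Read together with the `num_class` docstring hunk above, these defaults imply the following class layout; the sketch below only restates what the two hunks already say:

```python
# num_class = 9 classes in total:
#   0     -> blank (CTC)
#   1     -> negative / non-wakeup words
#   2..8  -> tokens of the wakeup word "你好米雅"
# so every token id lies in [0, num_class).
num_class = 9
wakeup_word_tokens = [2, 3, 4, 5, 6, 3, 7, 8]
assert all(0 <= t < num_class for t in wakeup_word_tokens)
```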
@@ -529,7 +529,7 @@ def run(rank, world_size, args):

     logging.info("About to create model")

-    model = Tdnn(params.feature_dim, params.number_class)
+    model = Tdnn(params.feature_dim, params.num_class)

     checkpoints = load_checkpoint_if_available(params=params, model=model)

@@ -55,7 +55,7 @@ def load_score(score_file: Path) -> Dict[str, float]:
    """
    Args:
      score_file: Path to score file. Each line has two columns.
-                 The first colume is utt-id, and the second one is score.
+                 The first column is utt-id, and the second one is score.
                  This score could be viewed as probability of being wakeup word.

    Returns:
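A minimal sketch of the two-column parsing this docstring describes, assuming whitespace-separated fields; the actual `load_score` in this file may differ in details:

```python
from pathlib import Path
from typing import Dict


def load_score_sketch(score_file: Path) -> Dict[str, float]:
    # Each line: "<utt-id> <score>", where score is the
    # probability of the utterance being the wakeup word.
    scores: Dict[str, float] = {}
    with open(score_file) as f:
        for line in f:
            utt_id, score = line.strip().split()
            scores[utt_id] = float(score)
    return scores
```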
@@ -81,9 +81,9 @@ def get_roc_and_auc(
       pos_dict: scores of positive samples.
       neg_dict: scores of negative samples.
    Return:
-      A tuple of three elements, which will be used to plot roc curve.
+      A tuple of three elements, which will be used to plot ROC curve.
       Refer to sklearn.metrics.roc_curve for meaning of the first and second elements.
-      The third element is area under the roc curve(AUC).
+      The third element is area under the ROC curve (AUC).
    """
    pos_scores = np.fromiter(pos_dict.values(), dtype=float)
    neg_scores = np.fromiter(neg_dict.values(), dtype=float)
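To make the `sklearn.metrics.roc_curve` reference concrete, a small sketch of computing the three returned elements from the two score dicts, assuming positives are labeled 1 and negatives 0 (the toy dicts stand in for `load_score` output):

```python
import numpy as np
from sklearn.metrics import auc, roc_curve

# Toy score dicts keyed by utt-id (stand-ins for load_score() output).
pos_dict = {"pos-utt-1": 0.97, "pos-utt-2": 0.83}
neg_dict = {"neg-utt-1": 0.12, "neg-utt-2": 0.40}

pos_scores = np.fromiter(pos_dict.values(), dtype=float)
neg_scores = np.fromiter(neg_dict.values(), dtype=float)

# Labels: 1 for wakeup-word (positive) utterances, 0 for negative ones.
y_true = np.concatenate([np.ones_like(pos_scores), np.zeros_like(neg_scores)])
y_score = np.concatenate([pos_scores, neg_scores])

fpr, tpr, thresholds = roc_curve(y_true, y_score)
area_under_curve = auc(fpr, tpr)  # the third returned element
```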
@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -57,7 +57,7 @@ def get_args():
         "--enable-speed-perturb",
         type=str2bool,
         default=False,
-        help="""channel of trianing set.
+        help="""Whether to enable speed perturbation of the training set.
         """,
     )
     return parser.parse_args()
@@ -8,7 +8,7 @@ stop_stage=6
 # HI_MIA and aishell dataset are used in this experiment.
 # musan dataset is used for data augmentation.
 #
-# For aishell dataset downlading and preparation,
+# For aishell dataset downloading and preparation,
 # refer to icefall/egs/aishell/ASR/prepare.sh.
 #
 # For HI_MIA and HI_MIA_CW dataset,
@@ -96,7 +96,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
 fi

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare HI_MIA and HI_MIA_CWmanifest"
+  log "Stage 1: Prepare HI_MIA and HI_MIA_CW manifest"
   mkdir -p data/manifests
   if [ ! -e data/manifests/.himia.done ]; then
     lhotse prepare himia $dl_dir/HiMia data/manifests
@@ -177,8 +177,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then

   train_file=data/fbank/cuts_train_himia${train_set_channel}-aishell-shuf.jsonl.gz
   if [ ! -f ${train_file} ]; then
-    # SingleCutSampler is prefered for this experiment.
-    # So `shuf` the training dataset here.
+    # SingleCutSampler is preferred for this experiment
+    # rather than DynamicBucketingSampler,
+    # since negative audios (Aishell) tend to be longer than positive ones (HiMia):
+    # if DynamicBucketingSampler is used, a batch may contain either all negative
+    # samples or all positive samples.
+    # So `shuf` the training dataset here and use SingleCutSampler to load data.
     cat <(gunzip -c data/fbank/aishell_cuts_train.jsonl.gz) \
         <(gunzip -c data/fbank/cuts_train${train_set_channel}.jsonl.gz) | \
         grep -v _sp | \
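To illustrate the sampler choice this comment motivates, a hedged lhotse sketch; `SingleCutSampler` is the name current when this recipe was written (later lhotse versions rename it `SimpleCutSampler`), and the path below is illustrative:

```python
from lhotse import CutSet
from lhotse.dataset import SingleCutSampler

# Illustrative path; prepare.sh writes
# data/fbank/cuts_train_himia${train_set_channel}-aishell-shuf.jsonl.gz.
# The file was shuffled with `shuf`, so consecutive cuts already mix
# positive (HiMia) and negative (Aishell) utterances.
cuts = CutSet.from_file("data/fbank/cuts_train_himia-aishell-shuf.jsonl.gz")

# SingleCutSampler keeps the shuffled order instead of regrouping cuts into
# duration buckets, so each batch mixes positives and negatives.
sampler = SingleCutSampler(cuts, max_duration=100)
```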
@@ -26,6 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Model training"
   python ./ctc_tdnn/train.py \
     --num-epochs $epoch \
+    --exp-dir $exp_dir \
     --max-duration $max_duration
 fi

@@ -34,7 +35,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   python ctc_tdnn/inference.py \
     --avg $avg \
     --epoch $epoch \
-    --exp-dir ${exp_dir}
+    --exp-dir $exp_dir
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -45,12 +46,12 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
       --score-file ${post_dir}/fst_${test_set}_score.txt
   done
   python ./local/auc.py \
     --legend himia_cw \
     --positive-score-file ${post_dir}/fst_test_score.txt \
     --negative-score-file ${post_dir}/fst_cw_test_score.txt

   python ./local/auc.py \
     --legend himia_aishell \
     --positive-score-file ${post_dir}/fst_test_score.txt \
     --negative-score-file ${post_dir}/fst_aishell_test_score.txt
 fi