mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-11 11:02:29 +00:00
update comments
This commit is contained in:
parent
2230669129
commit
e64a6e7bec
10
egs/himia/wuw/README.md
Normal file
10
egs/himia/wuw/README.md
Normal file
@ -0,0 +1,10 @@
|
||||
# Pretrained models and related logs/results.
|
||||
|
||||
## ctc tdnn baseline
|
||||
|
||||
AUC results for different epochs can be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline>
|
||||
|
||||
E.g., for epoch 15 and avg 1, the result log file is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/log/log-auc-himia_aishell-2023-03-16-16-02-56>
|
||||
|
||||
Corresponding ROC curve is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_15-avg_1/himia_aishell.pdf>
|
||||
|
@ -1,10 +0,0 @@
|
||||
# Pretrained models and releated logs/results.
|
||||
|
||||
## ctc tdnn baseline
|
||||
|
||||
Auc results for different epochs could be found at <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline>
|
||||
|
||||
E.g. for epoch 2 and avg 1, auc log file is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_2-avg_1/log/log-auc-himia_aishell-2023-03-16-16-02-56>
|
||||
|
||||
Corresponding ROC curve is: <https://huggingface.co/GuoLiyong/himia_ctc_tdnn_baseline/blob/main/exp_max_duration_100/post/epoch_2-avg_1/himia_aishell.pdf>
|
||||
|
@ -61,28 +61,49 @@ class FiniteStateTransducer:
|
||||
"""Represents a decoding graph for wake word detection."""
|
||||
|
||||
def __init__(self, graph: str) -> None:
|
||||
"""
|
||||
Construct a decoding graph in FST format given string format graph.
|
||||
|
||||
Args:
|
||||
graph: A string format fst. Each arc is separated by "\n".
|
||||
"""
|
||||
self.state_list = list()
|
||||
for arc_str in graph.split("\n"):
|
||||
arc = arc_str.strip().split()
|
||||
if len(arc) == 0:
|
||||
continue
|
||||
# An arc may contain 1, 2 or 4 elements, with format:
|
||||
# src_state [dst_state] [ilabel] [olabel]
|
||||
# 1 and 2 for final state
|
||||
# 4 for non-final state
|
||||
assert len(arc) in [1, 2, 4], f"{len(arc)} {arc_str}"
|
||||
arc = [int(element) for element in arc]
|
||||
src_state_id = arc[0]
|
||||
max_state_id = len(self.state_list) - 1
|
||||
if len(arc) == 4: # Non-final state
|
||||
# FST must be sorted
|
||||
if len(self.state_list) <= int(arc[0]):
|
||||
assert max_state_id <= src_state_id, (
|
||||
f"Fsa must be sorted by src_state, "
|
||||
f"while {max_state_id} <= {src_state_id}. Check your graph."
|
||||
)
|
||||
if max_state_id < src_state_id:
|
||||
new_state = State()
|
||||
self.state_list.append(new_state)
|
||||
self.state_list[int(arc[0])].add_arc(
|
||||
Arc(arc[0], arc[1], arc[2], arc[3])
|
||||
|
||||
self.state_list[src_state_id].add_arc(
|
||||
Arc(src_state_id, arc[1], arc[2], arc[3])
|
||||
)
|
||||
else:
|
||||
self.final_state_id = int(arc[0])
|
||||
assert (
|
||||
max_state_id == src_state_id
|
||||
), f"Final state seems unreachable. Check your graph."
|
||||
self.final_state_id = src_state_id
|
||||
|
||||
def to_str(self) -> None:
|
||||
fst_str = ""
|
||||
for state_idx in range(len(self.state_list)):
|
||||
number_states = len(self.state_list)
|
||||
if number_states == 0:
|
||||
return fst_str
|
||||
for state_idx in range(number_states):
|
||||
cur_state = self.state_list[state_idx]
|
||||
for arc_idx in range(len(cur_state.arc_list)):
|
||||
cur_arc = cur_state.arc_list[arc_idx]
|
||||
|
@ -21,7 +21,7 @@ from typing import List
|
||||
|
||||
def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
|
||||
"""
|
||||
A graph starts with blank/unknown and follwoing by wakeup word.
|
||||
A graph that starts with blank/unknown and is followed by the wakeup word.
|
||||
|
||||
Args:
|
||||
wakeup_word_tokens: A sequence of token ids corresponding to wakeup_word.
|
||||
|
@ -69,7 +69,7 @@ def get_params() -> AttributeDict:
|
||||
{
|
||||
"env_info": get_env_info(),
|
||||
"feature_dim": 80,
|
||||
"number_class": 9,
|
||||
"num_class": 9,
|
||||
}
|
||||
)
|
||||
return params
|
||||
@ -150,7 +150,7 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = Tdnn(params.feature_dim, params.number_class)
|
||||
model = Tdnn(params.feature_dim, params.num_class)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True)
|
||||
|
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
|
||||
# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
|
||||
# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
|
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
|
||||
# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -162,7 +162,7 @@ def get_params() -> AttributeDict:
|
||||
- feature_dim: The model input dim. It has to match the one used
|
||||
in computing features.
|
||||
|
||||
- number_class: Numer of classes. Each token will have a token id
|
||||
- num_class: Number of classes. Each token will have a token id
|
||||
from [0, num_class).
|
||||
In this recipe, 0 is usually kept for blank,
|
||||
and 1 is usually kept for negative words.
|
||||
@ -182,7 +182,7 @@ def get_params() -> AttributeDict:
|
||||
"valid_interval": 3000,
|
||||
# parameters for model
|
||||
"feature_dim": 80,
|
||||
"number_class": 9,
|
||||
"num_class": 9,
|
||||
# parameters for tokenizer
|
||||
"wakeup_word": "你好米雅",
|
||||
"wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8],
|
||||
@ -529,7 +529,7 @@ def run(rank, world_size, args):
|
||||
|
||||
logging.info("About to create model")
|
||||
|
||||
model = Tdnn(params.feature_dim, params.number_class)
|
||||
model = Tdnn(params.feature_dim, params.num_class)
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
|
@ -55,7 +55,7 @@ def load_score(score_file: Path) -> Dict[str, float]:
|
||||
"""
|
||||
Args:
|
||||
score_file: Path to score file. Each line has two columns.
|
||||
The first colume is utt-id, and the second one is score.
|
||||
The first column is utt-id, and the second one is score.
|
||||
This score could be viewed as probability of being wakeup word.
|
||||
|
||||
Returns:
|
||||
@ -81,9 +81,9 @@ def get_roc_and_auc(
|
||||
pos_dict: scores of positive samples.
|
||||
neg_dict: scores of negative samples.
|
||||
Return:
|
||||
A tuple of three elements, which will be used to plot roc curve.
|
||||
A tuple of three elements, which will be used to plot ROC curve.
|
||||
Refer to sklearn.metrics.roc_curve for meaning of the first and second elements.
|
||||
The third element is area under the roc curve(AUC).
|
||||
The third element is the area under the ROC curve (AUC).
|
||||
"""
|
||||
pos_scores = np.fromiter(pos_dict.values(), dtype=float)
|
||||
neg_scores = np.fromiter(neg_dict.values(), dtype=float)
|
||||
|
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
|
||||
# Copyright 2023 Xiaomi Corp. (Author: Liyong Guo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -57,7 +57,7 @@ def get_args():
|
||||
"--enable-speed-perturb",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""channel of trianing set.
|
||||
help="""channel of training set.
|
||||
""",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
@ -8,7 +8,7 @@ stop_stage=6
|
||||
# HI_MIA and aishell dataset are used in this experiment.
|
||||
# musan dataset is used for data augmentation.
|
||||
#
|
||||
# For aishell dataset downlading and preparation,
|
||||
# For aishell dataset downloading and preparation,
|
||||
# refer to icefall/egs/aishell/ASR/prepare.sh.
|
||||
#
|
||||
# For HI_MIA and HI_MIA_CW dataset,
|
||||
@ -177,8 +177,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
|
||||
train_file=data/fbank/cuts_train_himia${train_set_channel}-aishell-shuf.jsonl.gz
|
||||
if [ ! -f ${train_file} ]; then
|
||||
# SingleCutSampler is prefered for this experiment.
|
||||
# So `shuf` the training dataset here.
|
||||
# SingleCutSampler is preferred for this experiment
|
||||
# rather than DynamicBucketingSampler.
|
||||
# Since negative audios (Aishell) tend to be longer than positive ones (HiMia),
|
||||
# if DynamicBucketingSampler is used, a batch may contain only negative samples
|
||||
# or only positive samples.
|
||||
# So `shuf` the training dataset here and use SingleCutSampler to load data.
|
||||
cat <(gunzip -c data/fbank/aishell_cuts_train.jsonl.gz) \
|
||||
<(gunzip -c data/fbank/cuts_train${train_set_channel}.jsonl.gz) | \
|
||||
grep -v _sp | \
|
||||
|
@ -26,6 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Model training"
|
||||
python ./ctc_tdnn/train.py \
|
||||
--num-epochs $epoch \
|
||||
--exp_dir $exp_dir \
|
||||
--max-duration $max_duration
|
||||
fi
|
||||
|
||||
@ -34,7 +35,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
python ctc_tdnn/inference.py \
|
||||
--avg $avg \
|
||||
--epoch $epoch \
|
||||
--exp-dir ${exp_dir}
|
||||
--exp-dir $exp_dir
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
|
Loading…
x
Reference in New Issue
Block a user