Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-03 22:24:19 +00:00)

Commit 7d91e8b6d5: Fix wewetspeech prepare.sh
Parent: afe3b183c4
@@ -49,6 +49,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     pushd data
     ln -svf $(realpath ../ASR/data/lang_bpe_500) .
     popd
+    touch data/fbank/.gigaspeech.done
   else
     log "Gigaspeech dataset already exists, skipping."
   fi

@@ -63,7 +64,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
     ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
     ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
     pushd open-commands
-    ./script/prepare.sh --stage 3 --stop-stage 3
+    ./script/prepare.sh --stage 2 --stop-stage 2
     ./script/prepare.sh --stage 6 --stop-stage 6
     popd
     popd

@@ -186,13 +186,6 @@ def get_parser():
         help="The default threshold (probability) to trigger the keyword.",
     )

-    parser.add_argument(
-        "--keywords-version",
-        type=str,
-        default="",
-        help="The keywords configuration version, just to save results to different files.",
-    )
-
     parser.add_argument(
         "--num-tailing-blanks",
         type=int,

@@ -222,7 +215,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
-    kws_graph: Optional[ContextGraph] = None,
+    keywords_graph: Optional[ContextGraph] = None,
 ) -> List[List[Tuple[str, Tuple[int, int]]]]:
     """Decode one batch and return the result in a list.

@@ -242,7 +235,7 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      kws_graph:
+      keywords_graph:
         The graph containing keywords.
     Returns:
       Return the decoding result. See above description for the format of

@@ -274,7 +267,7 @@ def decode_one_batch(
         model=model,
         encoder_out=encoder_out,
         encoder_out_lens=encoder_out_lens,
-        keywords_graph=kws_graph,
+        keywords_graph=keywords_graph,
         beam=params.beam,
         num_tailing_blanks=params.num_tailing_blanks,
         blank_penalty=params.blank_penalty,

@@ -295,7 +288,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    kws_graph: ContextGraph,
+    keywords_graph: ContextGraph,
     keywords: Set[str],
     test_only_keywords: bool,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:

@@ -310,7 +303,7 @@ def decode_dataset(
        The neural model.
      sp:
        The BPE model.
-     kws_graph:
+     keywords_graph:
        The graph containing keywords.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search

@@ -341,7 +334,7 @@ def decode_dataset(
             params=params,
             model=model,
             sp=sp,
-            kws_graph=kws_graph,
+            keywords_graph=keywords_graph,
             batch=batch,
         )

@@ -459,7 +452,7 @@ def save_results(
         s += f"\tPrecision: {precision:.3f}\n"
         s += f"\tRecall(PPR): {recall:.3f}\n"
         s += f"\tFPR: {fpr:.3f}\n"
-        s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
+        s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
         if key != "all":
             s += f"\tTP list: {' # '.join(item.TP_list)}\n"
             s += f"\tFP list: {' # '.join(item.FP_list)}\n"
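
Note on the F1 line above: when a keyword never fires, precision and recall can both be 0, and the old expression divides by zero. The new expression short-circuits to 0.0 whenever either term is 0. A minimal, self-contained Python sketch of the guarded computation (the standalone helper is our illustration, not code from the patch):

    def f1_score(precision: float, recall: float) -> float:
        # F1 is the harmonic mean of precision and recall; it is 0 whenever
        # either term is 0, and the unguarded formula divides by zero when
        # both are.
        if precision * recall == 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)

    assert f1_score(0.0, 0.0) == 0.0  # the old expression would raise here
    assert f1_score(0.5, 0.5) == 0.5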

@@ -505,7 +498,7 @@ def main():
         params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
     if params.blank_penalty != 0:
         params.suffix += f"-blank-penalty-{params.blank_penalty}"
-    params.suffix += f"-version-{params.keywords_version}"
+    params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
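
Note: together with the removal of --keywords-version further up, decoding results are now keyed by the basename of the keywords file rather than by a free-form version string, so different keyword lists automatically land in different result files. A hypothetical illustration of the new suffix (the path is invented):

    keywords_file = "data/commands_small.txt"  # invented example path
    suffix = f"-keywords-{keywords_file.split('/')[-1]}"
    print(suffix)  # -keywords-commands_small.txt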

@@ -555,10 +548,10 @@ def main():

     params.keywords_config = "".join(keywords_config)

-    kws_graph = ContextGraph(
+    keywords_graph = ContextGraph(
         context_score=params.keywords_score, ac_threshold=params.keywords_threshold
     )
-    kws_graph.build(
+    keywords_graph.build(
         token_ids=token_ids,
         phrases=phrases,
         scores=keywords_scores,
@@ -677,7 +670,7 @@ def main():
             params=params,
             model=model,
             sp=sp,
-            kws_graph=kws_graph,
+            keywords_graph=keywords_graph,
             keywords=keywords,
             test_only_keywords="fsc" in test_set,
         )

@@ -276,7 +276,7 @@ def get_parser():
     return parser


-def add_model_arguments(parser: argparse.ArgumentParser):
+def add_training_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--world-size",
         type=int,

@@ -993,7 +993,7 @@ def keywords_search(
     """
     assert encoder_out.ndim == 3, encoder_out.shape
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    assert context_graph is not None
+    assert keywords_graph is not None

     packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
         input=encoder_out,

@@ -1018,7 +1018,7 @@ def keywords_search(
             Hypothesis(
                 ys=[-1] * (context_size - 1) + [blank_id],
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-                context_state=context_graph.root,
+                context_state=keywords_graph.root,
                 timestamp=[],
                 ac_probs=[],
             )

@@ -1125,7 +1125,7 @@ def keywords_search(
                     context_score,
                     new_context_state,
                     _,
-                ) = context_graph.forward_one_step(hyp.context_state, new_token)
+                ) = keywords_graph.forward_one_step(hyp.context_state, new_token)
                 new_num_tailing_blanks = 0
                 if new_context_state.token == -1:  # root
                     new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id]

@@ -1143,7 +1143,7 @@ def keywords_search(
             B[i].add(new_hyp)

         top_hyp = B[i].get_most_probable(length_norm=True)
-        matched, matched_state = context_graph.is_matched(top_hyp.context_state)
+        matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
         if matched:
             ac_prob = (
                 sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
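
Note on the acceptance test in this hunk: when the best hypothesis reaches a keyword end-state, the search averages the per-token acoustic probabilities over the last `level` tokens, i.e. the tokens of the matched keyword. Presumably this average is then compared against the ac_threshold given to ContextGraph earlier; a self-contained sketch of the averaging, with invented values:

    ac_probs = [0.2, 0.9, 0.8, 0.7]  # per-token acoustic probs (invented)
    level = 3                        # token length of the matched keyword
    ac_threshold = 0.6               # invented threshold
    # Average over the tokens belonging to the matched keyword only.
    ac_prob = sum(ac_probs[-level:]) / level
    print(ac_prob >= ac_threshold)   # True here, so the keyword would fire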

@@ -1164,7 +1164,7 @@ def keywords_search(
                 Hypothesis(
                     ys=[-1] * (context_size - 1) + [blank_id],
                     log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-                    context_state=context_graph.root,
+                    context_state=keywords_graph.root,
                     timestamp=[],
                     ac_probs=[],
                 )

@@ -1174,7 +1174,7 @@ def keywords_search(

     for i, hyps in enumerate(B):
         top_hyp = hyps.get_most_probable(length_norm=True)
-        matched, matched_state = context_graph.is_matched(top_hyp.context_state)
+        matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
         if matched:
             ac_prob = (
                 sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level

@@ -376,5 +376,6 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then
       --token-type $token \
       --lang-dir $lang_dir
   fi
+  python ./local/compile_lg.py --lang-dir $lang_dir
 done
 fi

@@ -22,63 +22,69 @@ log() {
 }

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "Stage 0: Prepare gigaspeech dataset."
+  log "Stage 0: Prepare wewetspeech dataset."
   mkdir -p data/fbank
-  if [ ! -e data/fbank/.gigaspeech.done ]; then
+  if [ ! -e data/fbank/.wewetspeech.done ]; then
     pushd ../ASR
-    ./prepare.sh --stage 0 --stop-stage 9
-    ./prepare.sh --stage 11 --stop-stage 11
+    ./prepare.sh --stage 0 --stop-stage 17
+    ./prepare.sh --stage 22 --stop-stage 22
     popd
     pushd data/fbank
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/XL_split) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_DEV.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/feats_DEV.lca) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/feats_TEST_NET.lca) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/feats_TEST_MEETING.lca) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_L.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/L_split_1000) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_M.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/M_split_1000) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_S.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/S_split_1000) .
     ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
     ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
     popd
     pushd data
-    ln -svf $(realpath ../ASR/data/lang_bpe_500) .
+    ln -svf $(realpath ../ASR/data/lang_partial_tone) .
     popd
+    touch data/fbank/.wewetspeech.done
   else
-    log "Gigaspeech dataset already exists, skipping."
+    log "WenetSpeech dataset already exists, skipping."
   fi
 fi

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Prepare open commands dataset."
   mkdir -p data/fbank
-  if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
+  if [ ! -e data/fbank/.cn_speech_commands.done ]; then
     pushd data
     git clone https://github.com/pkufool/open-commands.git
-    ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
-    ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
+    ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt
+    ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt
     pushd open-commands
-    ./script/prepare.sh --stage 3 --stop-stage 3
-    ./script/prepare.sh --stage 6 --stop-stage 6
+    ./script/prepare.sh --stage 1 --stop-stage 1
+    ./script/prepare.sh --stage 3 --stop-stage 5
     popd
     popd
     pushd data/fbank
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_large.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_large) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_small.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_small) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_dev.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_dev) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_test.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_test) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_train.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_train) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_clean.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_clean.lca) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_noisy.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_noisy.lca) .
     popd
-    touch data/fbank/.fluent_speech_commands.done
+    touch data/fbank/.cn_speech_commands.done
   else
-    log "Fluent speech commands dataset already exists, skipping."
+    log "CN speech commands dataset already exists, skipping."
   fi
 fi
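
Note: the prepare script guards each stage behind a hidden ".done" marker file so that reruns skip completed work; the old stage 0 checked for the marker but never created it, and this hunk adds the missing touch (and renames the marker to match the dataset). A hypothetical Python rendering of the same idempotency pattern:

    from pathlib import Path

    marker = Path("data/fbank/.wewetspeech.done")  # marker name as in the patch
    if not marker.exists():
        # ... run the expensive preparation and symlinking steps ...
        marker.touch()  # record completion so the next run skips this stage
    else:
        print("WenetSpeech dataset already exists, skipping.")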

@@ -37,7 +37,7 @@ fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Train a model."
-  if [ ! -e data/fbank/.gigaspeech.done ]; then
+  if [ ! -e data/fbank/.wenetspeech.done ]; then
     log "You need to run the prepare.sh first."
     exit -1
   fi

@@ -19,38 +19,7 @@
 # limitations under the License.
 """
 Usage:
-(1) greedy search
-./zipformer/decode.py \
-  --epoch 35 \
-  --avg 15 \
-  --exp-dir ./zipformer/exp \
-  --lang-dir data/lang_char \
-  --max-duration 600 \
-  --decoding-method greedy_search
-
-(2) modified beam search
-./zipformer/decode.py \
-  --epoch 35 \
-  --avg 15 \
-  --exp-dir ./zipformer/exp \
-  --lang-dir data/lang_char \
-  --max-duration 600 \
-  --decoding-method modified_beam_search \
-  --beam-size 4
-
-(3) fast beam search (trivial_graph)
-./zipformer/decode.py \
-  --epoch 35 \
-  --avg 15 \
-  --exp-dir ./zipformer/exp \
-  --lang-dir data/lang_char \
-  --max-duration 600 \
-  --decoding-method fast_beam_search \
-  --beam 20.0 \
-  --max-contexts 8 \
-  --max-states 64
-
-(4) fast beam search (LG)
+(1) fast beam search (LG)
 ./zipformer/decode.py \
   --epoch 30 \
   --avg 15 \
@@ -61,20 +30,6 @@ Usage:
   --beam 20.0 \
   --max-contexts 8 \
   --max-states 64
-
-(5) fast beam search (nbest oracle WER)
-./zipformer/decode.py \
-  --epoch 35 \
-  --avg 15 \
-  --exp-dir ./zipformer/exp \
-  --lang-dir data/lang_char \
-  --max-duration 600 \
-  --decoding-method fast_beam_search_nbest_oracle \
-  --beam 20.0 \
-  --max-contexts 8 \
-  --max-states 64 \
-  --num-paths 200 \
-  --nbest-scale 0.5
 """


@@ -17,19 +17,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Usage:
-(2) modified beam search
-./zipformer/decode.py \
-  --epoch 35 \
-  --avg 15 \
-  --exp-dir ./zipformer/exp \
-  --lang-dir data/lang_char \
-  --max-duration 600 \
-  --decoding-method modified_beam_search \
-  --beam-size 4
-"""
-

 import argparse
 import logging

@@ -207,13 +194,6 @@ def get_parser():
         help="The default threshold (probability) to trigger the keyword.",
     )

-    parser.add_argument(
-        "--keywords-version",
-        type=str,
-        default="",
-        help="The keywords configuration version, just to save results to different files.",
-    )
-
     parser.add_argument(
         "--num-tailing-blanks",
         type=int,

@@ -479,7 +459,7 @@ def save_results(
         s += f"\tPrecision: {precision:.3f}\n"
         s += f"\tRecall(PPR): {recall:.3f}\n"
         s += f"\tFPR: {fpr:.3f}\n"
-        s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
+        s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
         if key != "all":
             s += f"\tTP list: {' # '.join(item.TP_list)}\n"
             s += f"\tFP list: {' # '.join(item.FP_list)}\n"

@@ -525,7 +505,7 @@ def main():
         params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
     if params.blank_penalty != 0:
         params.suffix += f"-blank-penalty-{params.blank_penalty}"
-    params.suffix += f"-version-{params.keywords_version}"
+    params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")