mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-02 21:54:18 +00:00

Fix wenetspeech prepare.sh

This commit is contained in:
parent afe3b183c4
commit 7d91e8b6d5
@@ -49,6 +49,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     pushd data
+    ln -svf $(realpath ../ASR/data/lang_bpe_500) .
     popd
     touch data/fbank/.gigaspeech.done
   else
     log "Gigaspeech dataset already exists, skipping."
   fi
@@ -63,7 +64,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
     ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
     ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
     pushd open-commands
-    ./script/prepare.sh --stage 3 --stop-stage 3
+    ./script/prepare.sh --stage 2 --stop-stage 2
     ./script/prepare.sh --stage 6 --stop-stage 6
     popd
     popd
@@ -186,13 +186,6 @@ def get_parser():
         help="The default threshold (probability) to trigger the keyword.",
     )

-    parser.add_argument(
-        "--keywords-version",
-        type=str,
-        default="",
-        help="The keywords configuration version, just to save results to different files.",
-    )
-
     parser.add_argument(
         "--num-tailing-blanks",
         type=int,
@@ -222,7 +215,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
-    kws_graph: Optional[ContextGraph] = None,
+    keywords_graph: Optional[ContextGraph] = None,
 ) -> List[List[Tuple[str, Tuple[int, int]]]]:
     """Decode one batch and return the result in a list.

@@ -242,7 +235,7 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      kws_graph:
+      keywords_graph:
         The graph containing keywords.
     Returns:
       Return the decoding result. See above description for the format of
@@ -274,7 +267,7 @@ def decode_one_batch(
         model=model,
         encoder_out=encoder_out,
         encoder_out_lens=encoder_out_lens,
-        keywords_graph=kws_graph,
+        keywords_graph=keywords_graph,
         beam=params.beam,
         num_tailing_blanks=params.num_tailing_blanks,
         blank_penalty=params.blank_penalty,
@@ -295,7 +288,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
-    kws_graph: ContextGraph,
+    keywords_graph: ContextGraph,
     keywords: Set[str],
     test_only_keywords: bool,
 ) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]:
@@ -310,7 +303,7 @@ def decode_dataset(
         The neural model.
       sp:
         The BPE model.
-      kws_graph:
+      keywords_graph:
         The graph containing keywords.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
@@ -341,7 +334,7 @@ def decode_dataset(
             params=params,
             model=model,
             sp=sp,
-            kws_graph=kws_graph,
+            keywords_graph=keywords_graph,
             batch=batch,
         )

@@ -459,7 +452,7 @@ def save_results(
         s += f"\tPrecision: {precision:.3f}\n"
         s += f"\tRecall(PPR): {recall:.3f}\n"
         s += f"\tFPR: {fpr:.3f}\n"
-        s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
+        s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
         if key != "all":
            s += f"\tTP list: {' # '.join(item.TP_list)}\n"
            s += f"\tFP list: {' # '.join(item.FP_list)}\n"
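The guarded F1 above avoids a ZeroDivisionError when a keyword is never detected: if precision and recall are both 0, the harmonic-mean denominator is also 0. A minimal standalone sketch of the same guard (the function name safe_f1 is illustrative, not part of the patch):

    def safe_f1(precision: float, recall: float) -> float:
        # F1 is defined as 0 when either component is 0; the check also
        # prevents division by zero, since precision + recall == 0 only
        # when both are 0 (they are non-negative).
        if precision * recall == 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)

    assert safe_f1(0.0, 0.0) == 0.0  # old expression raised ZeroDivisionError
    assert abs(safe_f1(0.5, 0.5) - 0.5) < 1e-9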
@@ -505,7 +498,7 @@ def main():
         params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
     if params.blank_penalty != 0:
         params.suffix += f"-blank-penalty-{params.blank_penalty}"
-    params.suffix += f"-version-{params.keywords_version}"
+    params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -555,10 +548,10 @@ def main():

     params.keywords_config = "".join(keywords_config)

-    kws_graph = ContextGraph(
+    keywords_graph = ContextGraph(
         context_score=params.keywords_score, ac_threshold=params.keywords_threshold
     )
-    kws_graph.build(
+    keywords_graph.build(
         token_ids=token_ids,
         phrases=phrases,
         scores=keywords_scores,
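The rename from kws_graph to keywords_graph is mechanical, but it has to be applied at every construction and call site, as the hunks above and below show. A rough sketch of how the renamed graph is built and queried, using only the calls that appear in this patch (the import path, token ids, phrase, and scores are placeholder assumptions):

    from icefall.context_graph import ContextGraph  # assumed location

    keywords_graph = ContextGraph(context_score=1.5, ac_threshold=0.2)
    keywords_graph.build(
        token_ids=[[13, 42, 7]],  # BPE token ids of each keyword phrase
        phrases=["HEY WENET"],    # placeholder keyword text
        scores=[2.0],             # per-phrase boosting score
    )

    # Decoding advances one graph state per emitted token, mirroring the
    # three-value unpacking used in keywords_search below:
    state = keywords_graph.root
    score, state, _ = keywords_graph.forward_one_step(state, 13)
    matched, matched_state = keywords_graph.is_matched(state)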
@@ -677,7 +670,7 @@ def main():
            params=params,
            model=model,
            sp=sp,
-           kws_graph=kws_graph,
+           keywords_graph=keywords_graph,
            keywords=keywords,
            test_only_keywords="fsc" in test_set,
        )

@@ -276,7 +276,7 @@ def get_parser():
     return parser


-def add_model_arguments(parser: argparse.ArgumentParser):
+def add_training_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--world-size",
         type=int,
@@ -993,7 +993,7 @@ def keywords_search(
     """
     assert encoder_out.ndim == 3, encoder_out.shape
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    assert context_graph is not None
+    assert keywords_graph is not None

     packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
         input=encoder_out,
@@ -1018,7 +1018,7 @@ def keywords_search(
            Hypothesis(
                ys=[-1] * (context_size - 1) + [blank_id],
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-               context_state=context_graph.root,
+               context_state=keywords_graph.root,
                timestamp=[],
                ac_probs=[],
            )
@@ -1125,7 +1125,7 @@ def keywords_search(
                    context_score,
                    new_context_state,
                    _,
-               ) = context_graph.forward_one_step(hyp.context_state, new_token)
+               ) = keywords_graph.forward_one_step(hyp.context_state, new_token)
                new_num_tailing_blanks = 0
                if new_context_state.token == -1:  # root
                    new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id]
@@ -1143,7 +1143,7 @@ def keywords_search(
            B[i].add(new_hyp)

        top_hyp = B[i].get_most_probable(length_norm=True)
-       matched, matched_state = context_graph.is_matched(top_hyp.context_state)
+       matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
        if matched:
            ac_prob = (
                sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
@@ -1164,7 +1164,7 @@ def keywords_search(
                Hypothesis(
                    ys=[-1] * (context_size - 1) + [blank_id],
                    log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-                   context_state=context_graph.root,
+                   context_state=keywords_graph.root,
                    timestamp=[],
                    ac_probs=[],
                )
@@ -1174,7 +1174,7 @@ def keywords_search(

    for i, hyps in enumerate(B):
        top_hyp = hyps.get_most_probable(length_norm=True)
-       matched, matched_state = context_graph.is_matched(top_hyp.context_state)
+       matched, matched_state = keywords_graph.is_matched(top_hyp.context_state)
        if matched:
            ac_prob = (
                sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level
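In both matching sites above, the trigger confidence is the mean of the per-token acoustic probabilities over just the matched keyword, whose token length is matched_state.level. A small worked example with made-up numbers:

    # Hypothetical per-token acoustic probabilities of the best hypothesis;
    # the last `level` tokens spell the matched keyword.
    ac_probs = [0.31, 0.92, 0.88, 0.95]
    level = 3  # depth of the matched graph state == keyword length in tokens

    ac_prob = sum(ac_probs[-level:]) / level  # average over the keyword span
    print(f"{ac_prob:.3f}")  # 0.917; fires only if >= the keyword's threshold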
@@ -376,5 +376,6 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then
        --token-type $token \
        --lang-dir $lang_dir
    fi
+   python ./local/compile_lg.py --lang-dir $lang_dir
  done
fi
@@ -22,63 +22,69 @@ log() {
 }

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "Stage 0: Prepare gigaspeech dataset."
+  log "Stage 0: Prepare wenetspeech dataset."
   mkdir -p data/fbank
-  if [ ! -e data/fbank/.gigaspeech.done ]; then
+  if [ ! -e data/fbank/.wenetspeech.done ]; then
     pushd ../ASR
-    ./prepare.sh --stage 0 --stop-stage 9
-    ./prepare.sh --stage 11 --stop-stage 11
+    ./prepare.sh --stage 0 --stop-stage 17
+    ./prepare.sh --stage 22 --stop-stage 22
     popd
     pushd data/fbank
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
-    ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
-    ln -svf $(realpath ../ASR/data/fbank/XL_split) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_DEV.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/feats_DEV.lca) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/feats_TEST_NET.lca) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/feats_TEST_MEETING.lca) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_L.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/L_split_1000) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_M.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/M_split_1000) .
+    ln -svf $(realpath ../ASR/data/fbank/cuts_S.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/S_split_1000) .
+    ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
+    ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
     popd
     pushd data
-    ln -svf $(realpath ../ASR/data/lang_bpe_500) .
+    ln -svf $(realpath ../ASR/data/lang_partial_tone) .
     popd
-    touch data/fbank/.gigaspeech.done
+    touch data/fbank/.wenetspeech.done
   else
-    log "Gigaspeech dataset already exists, skipping."
+    log "WenetSpeech dataset already exists, skipping."
   fi
 fi

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Prepare open commands dataset."
   mkdir -p data/fbank
-  if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
+  if [ ! -e data/fbank/.cn_speech_commands.done ]; then
     pushd data
     git clone https://github.com/pkufool/open-commands.git
-    ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
-    ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
+    ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt
+    ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt
     pushd open-commands
-    ./script/prepare.sh --stage 3 --stop-stage 3
-    ./script/prepare.sh --stage 6 --stop-stage 6
+    ./script/prepare.sh --stage 1 --stop-stage 1
+    ./script/prepare.sh --stage 3 --stop-stage 5
     popd
     popd
     pushd data/fbank
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
-    ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_large.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_large) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_small.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_small) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_dev.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_dev) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_test.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_test) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_train.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_train) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_clean.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_clean.lca) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_noisy.jsonl.gz) .
+    ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_noisy.lca) .
     popd
-    touch data/fbank/.fluent_speech_commands.done
+    touch data/fbank/.cn_speech_commands.done
   else
-    log "Fluent speech commands dataset already exists, skipping."
+    log "CN speech commands dataset already exists, skipping."
   fi
 fi
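Both stages follow the same idempotence pattern: a hidden .done marker is checked on entry and touched only after every symlink is in place, so an interrupted run redoes the stage and a completed run skips it. A minimal sketch of that pattern in Python, as a stand-in for the shell logic above (run_stage is a hypothetical helper, not part of the repo):

    from pathlib import Path

    def run_stage(marker: Path, work) -> None:
        # Skip the stage entirely if a previous run completed it.
        if marker.exists():
            print(f"{marker.name}: already done, skipping.")
            return
        work()          # download, compute fbank, create symlinks, ...
        marker.touch()  # record completion only after the work succeeds

    run_stage(Path("data/fbank/.wenetspeech.done"),
              lambda: print("preparing WenetSpeech ..."))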
@@ -37,7 +37,7 @@ fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Train a model."
-  if [ ! -e data/fbank/.gigaspeech.done ]; then
+  if [ ! -e data/fbank/.wenetspeech.done ]; then
     log "You need to run the prepare.sh first."
     exit -1
   fi
@@ -19,38 +19,7 @@
 # limitations under the License.
 """
 Usage:
-(1) greedy search
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method greedy_search
-
-(2) modified beam search
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method modified_beam_search \
-    --beam-size 4
-
-(3) fast beam search (trivial_graph)
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method fast_beam_search \
-    --beam 20.0 \
-    --max-contexts 8 \
-    --max-states 64
-
-(4) fast beam search (LG)
+(1) fast beam search (LG)
 ./zipformer/decode.py \
     --epoch 30 \
     --avg 15 \
@@ -61,20 +30,6 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
-
-(5) fast beam search (nbest oracle WER)
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method fast_beam_search_nbest_oracle \
-    --beam 20.0 \
-    --max-contexts 8 \
-    --max-states 64 \
-    --num-paths 200 \
-    --nbest-scale 0.5
 """

@@ -17,19 +17,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Usage:
-(2) modified beam search
-./zipformer/decode.py \
-    --epoch 35 \
-    --avg 15 \
-    --exp-dir ./zipformer/exp \
-    --lang-dir data/lang_char \
-    --max-duration 600 \
-    --decoding-method modified_beam_search \
-    --beam-size 4
-"""
-

 import argparse
 import logging
@@ -207,13 +194,6 @@ def get_parser():
         help="The default threshold (probability) to trigger the keyword.",
     )

-    parser.add_argument(
-        "--keywords-version",
-        type=str,
-        default="",
-        help="The keywords configuration version, just to save results to different files.",
-    )
-
     parser.add_argument(
         "--num-tailing-blanks",
         type=int,
@@ -479,7 +459,7 @@ def save_results(
         s += f"\tPrecision: {precision:.3f}\n"
         s += f"\tRecall(PPR): {recall:.3f}\n"
         s += f"\tFPR: {fpr:.3f}\n"
-        s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n"
+        s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n"
         if key != "all":
            s += f"\tTP list: {' # '.join(item.TP_list)}\n"
            s += f"\tFP list: {' # '.join(item.FP_list)}\n"
@@ -525,7 +505,7 @@ def main():
         params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}"
     if params.blank_penalty != 0:
         params.suffix += f"-blank-penalty-{params.blank_penalty}"
-    params.suffix += f"-version-{params.keywords_version}"
+    params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}"

     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")