add model export scripts

This commit is contained in:
pkufool 2024-02-06 17:23:27 +08:00
parent 91f13826d7
commit 63c6dd90f5
10 changed files with 1192 additions and 1 deletion

84
egs/gigaspeech/KWS/prepare.sh Executable file

@@ -0,0 +1,84 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
stage=0
stop_stage=100
. shared/parse_options.sh || exit 1
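# The stage range can be overridden from the command line, e.g.:
#   ./prepare.sh --stage 1 --stop-stage 1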
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Prepare gigaspeech dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.gigaspeech.done ]; then
pushd ../ASR
./prepare.sh --stage 0 --stop-stage 9
./prepare.sh --stage 11 --stop-stage 11
popd
pushd data/fbank
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
ln -svf $(realpath ../ASR/data/fbank/XL_split) .
ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
popd
pushd data
ln -svf $(realpath ../ASR/data/lang_bpe_500) .
popd
touch data/fbank/.gigaspeech.done
else
log "Gigaspeech dataset already exists, skipping."
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare open commands dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
pushd data
git clone https://github.com/pkufool/open-commands.git
ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
pushd open-commands
./script/prepare.sh --stage 3 --stop-stage 3
./script/prepare.sh --stage 6 --stop-stage 6
popd
popd
pushd data/fbank
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
popd
touch data/fbank/.fluent_speech_commands.done
else
log "Fluent speech commands dataset already exists, skipping."
fi
fi

202
egs/gigaspeech/KWS/run.sh Normal file

@@ -0,0 +1,202 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export PYTHONPATH=../../../:$PYTHONPATH
stage=0
stop_stage=100
pre_trained_model_host=github
. shared/parse_options.sh || exit 1
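# Run a single stage with, e.g.: ./run.sh --stage 2 --stop-stage 2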
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download a pre-trained model."
if [ "$pre_trained_model_host" = "github" ]; then
elif [ "$pre_trained_model_host" = "modelscope" ]; then
else
log "Pretrained model host '$pre_trained_model_host' is not supported."
exit -1;
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Train a model."
if [ ! -e data/fbank/.gigaspeech.done ]; then
log "You need to run the prepare.sh first."
exit -1
fi
python ./zipformer/train.py \
--world-size 4 \
--exp-dir zipformer/exp \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--num-epochs 15 \
--lr-epochs 1.5 \
--use-fp16 1 \
--start-epoch 1 \
--training-subset L \
--pinyin-type partial_with_tone \
--causal 1 \
--lang-dir data/lang_partial_tone \
--max-duration 1000
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Decode the model."
for t in small large; do
python ./zipformer/decode.py \
--epoch 15 \
--avg 2 \
--exp-dir ./zipformer/exp \
--lang-dir ./data/lang_partial_tone \
--pinyin-type partial_with_tone \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--test-set $t \
--keywords-score 1.0 \
--keywords-threshold 0.35 \
--keywords-file ./data/commands_${t}.txt \
--max-duration 3000
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Export the model."
python ./zipformer/export.py \
--epoch 15 \
--avg 2 \
--exp-dir ./zipformer/exp \
--tokens data/lang_partial_tone/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128
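# Note: export.py above writes a PyTorch checkpoint (pretrained.pt) to the exp
# dir, while export_onnx_streaming.py below typically exports streaming ONNX
# models (encoder, decoder and joiner) for deployment.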
python ./zipformer/export_onnx_streaming.py \
--exp-dir zipformer/exp \
--tokens data/lang_partial_tone/tokens.txt \
--epoch 15 \
--avg 2 \
--chunk-size 16 \
--left-context-frames 128 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 2: Finetune the model"
# The following configuration of lr schedule should work well
# You may also tune the following parameters to adjust learning rate schedule
base_lr=0.0005
lr_epochs=100
lr_batches=100000
# We recommend to start from an averaged model
finetune_ckpt=zipformer/exp/pretrained.pt
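# Stage 2 above (export.py) normally writes zipformer/exp/pretrained.pt, which
# can be used directly as this checkpoint.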
./zipformer/finetune.py \
--world-size 4 \
--num-epochs 10 \
--start-epoch 1 \
--exp-dir zipformer/exp_finetune \
--lang-dir ./data/lang_partial_tone \
--pinyin-type partial_with_tone \
--use-fp16 1 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1 \
--base-lr $base_lr \
--lr-epochs $lr_epochs \
--lr-batches $lr_batches \
--finetune-ckpt $finetune_ckpt \
--max-duration 1500
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 1: Decode the finetuned model."
for t in small, large; do
python ./zipformer/decode.py \
--epoch 15 \
--avg 2 \
--exp-dir ./zipformer/exp_finetune \
--lang-dir ./data/lang_partial_tone \
--pinyin-type partial_with_tone \
--causal 1 \
--chunk-size 16 \
--left-context-frames 64 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--test-set $t \
--keywords-score 1.0 \
--keywords-threshold 0.35 \
--keywords-file ./data/commands_${t}.txt \
--max-duration 3000
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 2: Export the finetuned model."
python ./zipformer/export_onnx_streaming.py \
--exp-dir zipformer/exp_finetune \
--tokens data/lang_partial_tone/tokens.txt \
--epoch 15 \
--avg 2 \
--chunk-size 16 \
--left-context-frames 128 \
--decoder-dim 320 \
--joiner-dim 320 \
--num-encoder-layers 1,1,1,1,1,1 \
--feedforward-dim 192,192,192,192,192,192 \
--encoder-dim 128,128,128,128,128,128 \
--encoder-unmasked-dim 128,128,128,128,128,128 \
--causal 1
fi

1
egs/gigaspeech/KWS/shared Symbolic link

@@ -0,0 +1 @@
../../../icefall/shared

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

84
egs/wenetspeech/KWS/prepare.sh Executable file

@@ -0,0 +1,84 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
stage=0
stop_stage=100
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Prepare gigaspeech dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.gigaspeech.done ]; then
pushd ../ASR
./prepare.sh --stage 0 --stop-stage 9
./prepare.sh --stage 11 --stop-stage 11
popd
pushd data/fbank
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) .
ln -svf $(realpath ../ASR/data/fbank/XL_split) .
ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) .
ln -svf $(realpath ../ASR/data/fbank/musan_feats) .
popd
pushd data
ln -svf $(realpath ../ASR/data/lang_bpe_500) .
popd
touch data/fbank/.gigaspeech.done
else
log "Gigaspeech dataset already exists, skipping."
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare open commands dataset."
mkdir -p data/fbank
if [ ! -e data/fbank/.fluent_speech_commands.done ]; then
pushd data
git clone https://github.com/pkufool/open-commands.git
ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt
ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt
pushd open-commands
./script/prepare.sh --stage 3 --stop-stage 3
./script/prepare.sh --stage 6 --stop-stage 6
popd
popd
pushd data/fbank
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) .
ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) .
popd
touch data/fbank/.fluent_speech_commands.done
else
log "Fluent speech commands dataset already exists, skipping."
fi
fi

View File

@@ -24,12 +24,17 @@ log() {
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download a pre-trained model."
if [ "$pre_trained_model_host" = "github" ]; then
elif [ "$pre_trained_model_host" = "modelscope" ]; then
else
log "Pretrained model host '$pre_trained_model_host' is not supported."
exit -1;
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Train a model."
if [ ! -e data/fbank/.gigaspeech.done ]; then

View File

@@ -0,0 +1,812 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method greedy_search
(2) modified beam search
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(3) fast beam search (trivial_graph)
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(4) fast beam search (LG)
./zipformer/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest oracle WER)
./zipformer/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./zipformer/exp \
--lang-dir data/lang_char \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest_oracle
If you use fast_beam_search_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_LG, and fast_beam_search_nbest_oracle.
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--ilme-scale",
type=float,
default=0.2,
help="""
Used only when --decoding_method is fast_beam_search_LG.
It specifies the scale for the internal language model estimation.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. It can be either a `k2.trivial_graph` or an LG graph.
Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_LG":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
sentence = "".join([lexicon.word_table[i] for i in hyp])
hyps.append(list(sentence))
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
blank_penalty=params.blank_penalty,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
blank_penalty=params.blank_penalty,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search_" + key: hyps}
elif "fast_beam_search" in params.decoding_method:
key += f"_beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ilme_scale_{params.ilme_scale}"
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}_" + key: hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
graph_compiler: CharCtcTrainingGraphCompiler,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
decoding_graph:
The decoding graph. It can be either a `k2.trivial_graph` or an LG graph.
Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, or fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut ID, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest_oracle",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"_ilme_scale_{params.ilme_scale}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args)
def remove_short_utt(c: Cut):
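# Roughly the number of output frames after the convolutional front-end
# subsampling (about a factor of 4); cuts that would yield no frames are dropped.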
T = ((c.num_frames - 7) // 2 + 1) // 2
if T <= 0:
logging.warning(
f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
)
return T > 0
dev_cuts = wenetspeech.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt)
dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
test_net_cuts = wenetspeech.test_net_cuts()
test_net_cuts = test_net_cuts.filter(remove_short_utt)
test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
test_meeting_cuts = wenetspeech.test_meeting_cuts()
test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dls = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
graph_compiler=graph_compiler,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx-streaming.py