mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Merge branch 'dev_export_diff_acoustic_units' of https://github.com/JinZr/icefall into dev_export_diff_acoustic_units
This commit is contained in:
commit
7242d9b5f6
@ -40,7 +40,7 @@ for m in ctc-decoding 1best; do
|
|||||||
--model-filename $repo/exp/jit_trace.pt \
|
--model-filename $repo/exp/jit_trace.pt \
|
||||||
--words-file $repo/data/lang_bpe_500/words.txt \
|
--words-file $repo/data/lang_bpe_500/words.txt \
|
||||||
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
--HLG $repo/data/lang_bpe_500/HLG.pt \
|
||||||
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
--G $repo/data/lm/G_4_gram.pt \
|
--G $repo/data/lm/G_4_gram.pt \
|
||||||
--method $m \
|
--method $m \
|
||||||
--sample-rate 16000 \
|
--sample-rate 16000 \
|
||||||
|
12
.github/scripts/test-ncnn-export.sh
vendored
12
.github/scripts/test-ncnn-export.sh
vendored
@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -56,11 +55,10 @@ log "Export via torch.jit.trace()"
|
|||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
\
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--num-encoder-layers 12 \
|
--num-encoder-layers 12 \
|
||||||
--chunk-length 32 \
|
--chunk-length 32 \
|
||||||
--cnn-module-kernel 31 \
|
--cnn-module-kernel 31 \
|
||||||
@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -102,7 +99,7 @@ log "Export via torch.jit.trace()"
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export-for-ncnn.py \
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0
|
--use-averaged-model 0
|
||||||
@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
|||||||
repo=$(basename $repo_url)
|
repo=$(basename $repo_url)
|
||||||
|
|
||||||
pushd $repo
|
pushd $repo
|
||||||
git lfs pull --include "data/lang_bpe_500/bpe.model"
|
|
||||||
git lfs pull --include "exp/pretrained.pt"
|
git lfs pull --include "exp/pretrained.pt"
|
||||||
|
|
||||||
cd exp
|
cd exp
|
||||||
@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt
|
|||||||
popd
|
popd
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt
|
|||||||
popd
|
popd
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
|
||||||
--lang-dir $repo/data/lang_char_bpe \
|
--tokens $repo/data/lang_char_bpe/tokens.txt \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
|
16
.github/scripts/test-onnx-export.sh
vendored
16
.github/scripts/test-onnx-export.sh
vendored
@ -155,7 +155,7 @@ log "Export via torch.jit.trace()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
./pruned_transducer_stateless7_streaming/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -204,7 +204,7 @@ popd
|
|||||||
log "Export via torch.jit.script()"
|
log "Export via torch.jit.script()"
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export.py \
|
./pruned_transducer_stateless3/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--exp-dir $repo/exp/ \
|
--exp-dir $repo/exp/ \
|
||||||
@ -213,7 +213,7 @@ log "Export via torch.jit.script()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless3/export-onnx.py \
|
./pruned_transducer_stateless3/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--exp-dir $repo/exp/
|
--exp-dir $repo/exp/
|
||||||
@ -258,7 +258,7 @@ popd
|
|||||||
log "Export via torch.jit.script()"
|
log "Export via torch.jit.script()"
|
||||||
|
|
||||||
./pruned_transducer_stateless5/export.py \
|
./pruned_transducer_stateless5/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -274,7 +274,7 @@ log "Export via torch.jit.script()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./pruned_transducer_stateless5/export-onnx.py \
|
./pruned_transducer_stateless5/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -384,7 +384,7 @@ popd
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./conv_emformer_transducer_stateless2/export-onnx.py \
|
./conv_emformer_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -424,7 +424,7 @@ popd
|
|||||||
log "Export via torch.jit.trace()"
|
log "Export via torch.jit.trace()"
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -434,7 +434,7 @@ log "Export via torch.jit.trace()"
|
|||||||
log "Test exporting to ONNX format"
|
log "Test exporting to ONNX format"
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export-onnx.py \
|
./lstm_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
|
@ -139,8 +139,8 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
|
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
|
||||||
@ -154,7 +154,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -211,10 +211,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -675,12 +675,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -410,10 +410,20 @@ def main():
|
|||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
if params.method == "ctc-decoding":
|
||||||
words = " ".join(hyp)
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = words.replace("▁", " ").strip()
|
words = "".join(hyp)
|
||||||
s += f"{filename}:\n{words}\n\n"
|
words = words.replace("▁", " ").strip()
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
elif params.method in [
|
||||||
|
"1best",
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
]:
|
||||||
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
words = words.replace("▁", " ").strip()
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -274,7 +274,7 @@ def main():
|
|||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
params.vocab_size = num_tokens(token_table)
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for blank
|
||||||
params.blank_id = token_table["<blk>"]
|
params.blank_id = token_table["<blk>"]
|
||||||
assert params.blank_id == 0
|
assert params.blank_id == 0
|
||||||
|
|
||||||
@ -429,10 +429,20 @@ def main():
|
|||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
if params.method == "ctc-decoding":
|
||||||
words = " ".join(hyp)
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = words.replace("▁", " ").strip()
|
words = "".join(hyp)
|
||||||
s += f"{filename}:\n{words}\n\n"
|
words = words.replace("▁", " ").strip()
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
elif params.method in [
|
||||||
|
"1best",
|
||||||
|
"nbest-rescoring",
|
||||||
|
"whole-lattice-rescoring",
|
||||||
|
]:
|
||||||
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
words = words.replace("▁", " ").strip()
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user