Merge branch 'dev_export_diff_acoustic_units' of https://github.com/JinZr/icefall into dev_export_diff_acoustic_units

This commit is contained in:
JinZr 2023-08-12 16:02:43 +08:00
commit 7242d9b5f6
6 changed files with 53 additions and 35 deletions

View File

@ -40,7 +40,7 @@ for m in ctc-decoding 1best; do
--model-filename $repo/exp/jit_trace.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--G $repo/data/lm/G_4_gram.pt \
--method $m \
--sample-rate 16000 \

View File

@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
cd exp
@ -56,11 +55,10 @@ log "Export via torch.jit.trace()"
./conv_emformer_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
\
--tokens $repo/data/lang_bpe_500/tokens.txt \
--num-encoder-layers 12 \
--chunk-length 32 \
--cnn-module-kernel 31 \
@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp
@ -102,7 +99,7 @@ log "Export via torch.jit.trace()"
./lstm_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0
@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt
popd
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt
popd
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--tokens $repo/data/lang_char_bpe/tokens.txt \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 9999 \

View File

@ -155,7 +155,7 @@ log "Export via torch.jit.trace()"
log "Test exporting to ONNX format"
./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -204,7 +204,7 @@ popd
log "Export via torch.jit.script()"
./pruned_transducer_stateless3/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/ \
@ -213,7 +213,7 @@ log "Export via torch.jit.script()"
log "Test exporting to ONNX format"
./pruned_transducer_stateless3/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
@ -258,7 +258,7 @@ popd
log "Export via torch.jit.script()"
./pruned_transducer_stateless5/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -274,7 +274,7 @@ log "Export via torch.jit.script()"
log "Test exporting to ONNX format"
./pruned_transducer_stateless5/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -384,7 +384,7 @@ popd
log "Test exporting to ONNX format"
./conv_emformer_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -424,7 +424,7 @@ popd
log "Export via torch.jit.trace()"
./lstm_transducer_stateless2/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -434,7 +434,7 @@ log "Export via torch.jit.trace()"
log "Test exporting to ONNX format"
./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \

View File

@ -139,8 +139,8 @@ import argparse
import logging
from pathlib import Path
import k2
import onnxruntime
import sentencepiece as spm
import torch
import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
@ -154,7 +154,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -211,10 +211,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@ -675,12 +675,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -410,10 +410,20 @@ def main():
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("", " ").strip()
s += f"{filename}:\n{words}\n\n"
if params.method == "ctc-decoding":
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
words = words.replace("", " ").strip()
s += f"{filename}:\n{words}\n\n"
elif params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
]:
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -274,7 +274,7 @@ def main():
params.update(vars(args))
token_table = k2.SymbolTable.from_file(params.tokens)
params.vocab_size = num_tokens(token_table)
params.vocab_size = num_tokens(token_table) + 1 # +1 for blank
params.blank_id = token_table["<blk>"]
assert params.blank_id == 0
@ -429,10 +429,20 @@ def main():
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("", " ").strip()
s += f"{filename}:\n{words}\n\n"
if params.method == "ctc-decoding":
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
words = words.replace("", " ").strip()
s += f"{filename}:\n{words}\n\n"
elif params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
]:
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")