Merge branch 'dev_export_diff_acoustic_units' of https://github.com/JinZr/icefall into dev_export_diff_acoustic_units

This commit is contained in:
JinZr 2023-08-12 16:02:43 +08:00
commit 7242d9b5f6
6 changed files with 53 additions and 35 deletions

View File

@ -40,7 +40,7 @@ for m in ctc-decoding 1best; do
--model-filename $repo/exp/jit_trace.pt \ --model-filename $repo/exp/jit_trace.pt \
--words-file $repo/data/lang_bpe_500/words.txt \ --words-file $repo/data/lang_bpe_500/words.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \ --HLG $repo/data/lang_bpe_500/HLG.pt \
--tokens $repo/data/lang_bpe_500/tokens.txt \ --bpe-model $repo/data/lang_bpe_500/bpe.model \
--G $repo/data/lm/G_4_gram.pt \ --G $repo/data/lm/G_4_gram.pt \
--method $m \ --method $m \
--sample-rate 16000 \ --sample-rate 16000 \

View File

@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url) repo=$(basename $repo_url)
pushd $repo pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt"
cd exp cd exp
@ -56,11 +55,10 @@ log "Export via torch.jit.trace()"
./conv_emformer_transducer_stateless2/export-for-ncnn.py \ ./conv_emformer_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--use-averaged-model 0 \ --use-averaged-model 0 \
\ --tokens $repo/data/lang_bpe_500/tokens.txt \
--num-encoder-layers 12 \ --num-encoder-layers 12 \
--chunk-length 32 \ --chunk-length 32 \
--cnn-module-kernel 31 \ --cnn-module-kernel 31 \
@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url) repo=$(basename $repo_url)
pushd $repo pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt"
cd exp cd exp
@ -102,7 +99,7 @@ log "Export via torch.jit.trace()"
./lstm_transducer_stateless2/export-for-ncnn.py \ ./lstm_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--use-averaged-model 0 --use-averaged-model 0
@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url) repo=$(basename $repo_url)
pushd $repo pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt" git lfs pull --include "exp/pretrained.pt"
cd exp cd exp
@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt
popd popd
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt
popd popd
./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ ./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \ --tokens $repo/data/lang_char_bpe/tokens.txt \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 9999 \ --epoch 9999 \

View File

@ -155,7 +155,7 @@ log "Export via torch.jit.trace()"
log "Test exporting to ONNX format" log "Test exporting to ONNX format"
./pruned_transducer_stateless7_streaming/export-onnx.py \ ./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
@ -204,7 +204,7 @@ popd
log "Export via torch.jit.script()" log "Export via torch.jit.script()"
./pruned_transducer_stateless3/export.py \ ./pruned_transducer_stateless3/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--exp-dir $repo/exp/ \ --exp-dir $repo/exp/ \
@ -213,7 +213,7 @@ log "Export via torch.jit.script()"
log "Test exporting to ONNX format" log "Test exporting to ONNX format"
./pruned_transducer_stateless3/export-onnx.py \ ./pruned_transducer_stateless3/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--exp-dir $repo/exp/ --exp-dir $repo/exp/
@ -258,7 +258,7 @@ popd
log "Export via torch.jit.script()" log "Export via torch.jit.script()"
./pruned_transducer_stateless5/export.py \ ./pruned_transducer_stateless5/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--use-averaged-model 0 \ --use-averaged-model 0 \
@ -274,7 +274,7 @@ log "Export via torch.jit.script()"
log "Test exporting to ONNX format" log "Test exporting to ONNX format"
./pruned_transducer_stateless5/export-onnx.py \ ./pruned_transducer_stateless5/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--use-averaged-model 0 \ --use-averaged-model 0 \
@ -384,7 +384,7 @@ popd
log "Test exporting to ONNX format" log "Test exporting to ONNX format"
./conv_emformer_transducer_stateless2/export-onnx.py \ ./conv_emformer_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
@ -424,7 +424,7 @@ popd
log "Export via torch.jit.trace()" log "Export via torch.jit.trace()"
./lstm_transducer_stateless2/export.py \ ./lstm_transducer_stateless2/export.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
@ -434,7 +434,7 @@ log "Export via torch.jit.trace()"
log "Test exporting to ONNX format" log "Test exporting to ONNX format"
./lstm_transducer_stateless2/export-onnx.py \ ./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \

View File

@ -139,8 +139,8 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import k2
import onnxruntime import onnxruntime
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
@ -154,7 +154,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -211,10 +211,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt",
) )
parser.add_argument( parser.add_argument(
@ -675,12 +675,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -410,6 +410,16 @@ def main():
raise ValueError(f"Unsupported decoding method: {params.method}") raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n" s = "\n"
if params.method == "ctc-decoding":
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
elif params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
]:
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) words = " ".join(hyp)
words = words.replace("▁", " ").strip() words = words.replace("▁", " ").strip()

View File

@ -274,7 +274,7 @@ def main():
params.update(vars(args)) params.update(vars(args))
token_table = k2.SymbolTable.from_file(params.tokens) token_table = k2.SymbolTable.from_file(params.tokens)
params.vocab_size = num_tokens(token_table) params.vocab_size = num_tokens(token_table) + 1 # +1 for blank
params.blank_id = token_table["<blk>"] params.blank_id = token_table["<blk>"]
assert params.blank_id == 0 assert params.blank_id == 0
@ -429,6 +429,16 @@ def main():
raise ValueError(f"Unsupported decoding method: {params.method}") raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n" s = "\n"
if params.method == "ctc-decoding":
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
elif params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
]:
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) words = " ".join(hyp)
words = words.replace("▁", " ").strip() words = words.replace("▁", " ").strip()