Fix zipformer-ctc

This commit is contained in:
pkufool 2023-06-23 21:51:31 +08:00
parent 93dd3f5887
commit 655d170374
5 changed files with 37 additions and 37 deletions

View File

@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
@ -33,7 +34,7 @@ log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
@ -46,7 +47,7 @@ ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"
./zipformer/jit_pretrained_streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
$repo/test_wavs/1089-134686-0001.wav
@ -60,7 +61,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/jit_script.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
@ -33,7 +34,7 @@ log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@ -43,7 +44,7 @@ ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"
./zipformer/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--nn-model-filename $repo/exp/jit_script.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
@ -56,7 +57,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
@ -40,7 +41,7 @@ log "Export to torchscript model"
--use-transducer 1 \
--use-ctc 1 \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1

View File

@ -24,7 +24,7 @@ You can generate the checkpoint with the following command:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -35,7 +35,7 @@ You can generate the checkpoint with the following command:
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--causal 1 \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -45,7 +45,7 @@ Usage of this script:
(1) ctc-decoding
./zipformer/jit_pretrained_ctc.py \
--model-filename ./zipformer/exp/jit_script.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method ctc-decoding \
--sample-rate 16000 \
/path/to/foo.wav \
@ -91,10 +91,10 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from ctc_decode import get_decoding_params
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import get_params
@ -136,9 +136,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.
help="""Path to tokens.txt.
Used only when method is ctc-decoding.
""",
)
@ -149,8 +149,8 @@ def get_parser():
default="1best",
help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence
piece model, i.e., lang_dir/bpe.model, to convert
(0) ctc-decoding - Use CTC decoding. It uses a token table,
i.e., lang_dir/token.txt, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
@ -263,10 +263,8 @@ def main():
params.update(get_decoding_params())
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
params.vocab_size = sp.get_piece_size()
token_table = k2.SymbolTable.from_file(params.tokens)
params.vocab_size = num_tokens(token_table)
logging.info(f"{params}")
@ -340,8 +338,7 @@ def main():
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
hyps = sp.decode(token_ids)
hyps = [s.split() for s in hyps]
hyps = [[token_table[i] for i in ids] for ids in token_ids]
elif params.method in [
"1best",
"nbest-rescoring",
@ -415,6 +412,7 @@ def main():
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)

View File

@ -24,7 +24,7 @@ You can generate the checkpoint with the following command:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
@ -34,7 +34,7 @@ You can generate the checkpoint with the following command:
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--causal 1 \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
@ -43,7 +43,7 @@ Usage of this script:
(1) ctc-decoding
./zipformer/pretrained_ctc.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method ctc-decoding \
--sample-rate 16000 \
/path/to/foo.wav \
@ -90,12 +90,12 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from ctc_decode import get_decoding_params
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_model
from train import add_model_arguments, get_model, get_params
from icefall.decode import (
get_lattice,
@ -144,9 +144,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.
help="""Path to tokens.txt.
Used only when method is ctc-decoding.
""",
)
@ -157,8 +157,8 @@ def get_parser():
default="1best",
help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence
piece model, i.e., lang_dir/bpe.model, to convert
(0) ctc-decoding - Use CTC decoding. It uses a token table,
i.e., lang_dir/tokens.txt, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
@ -273,11 +273,10 @@ def main():
params.update(get_decoding_params())
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
params.vocab_size = sp.get_piece_size()
params.blank_id = 0
token_table = k2.SymbolTable.from_file(params.tokens)
params.vocab_size = num_tokens(token_table)
params.blank_id = token_table["blk"]
assert params.blank_id == 0
logging.info(f"{params}")
@ -358,8 +357,7 @@ def main():
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
hyps = sp.decode(token_ids)
hyps = [s.split() for s in hyps]
hyps = [[token_table[i] for i in ids] for ids in token_ids]
elif params.method in [
"1best",
"nbest-rescoring",
@ -433,6 +431,7 @@ def main():
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)