Fix zipformer-ctc

pkufool 2023-06-23 21:51:31 +08:00
parent 93dd3f5887
commit 655d170374
5 changed files with 37 additions and 37 deletions

View File

@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
 pushd $repo/exp
 git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
 git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
 git lfs pull --include "exp/pretrained.pt"
 ln -s pretrained.pt epoch-99.pt
@@ -33,7 +34,7 @@ log "Export to torchscript model"
 ./zipformer/export.py \
 --exp-dir $repo/exp \
 --use-averaged-model false \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
 --causal 1 \
 --chunk-size 16 \
 --left-context-frames 128 \
@@ -46,7 +47,7 @@ ls -lh $repo/exp/*.pt
 log "Decode with models exported by torch.jit.script()"
 ./zipformer/jit_pretrained_streaming.py \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
 --nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
 $repo/test_wavs/1089-134686-0001.wav
@@ -60,7 +61,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
 --method $method \
 --beam-size 4 \
 --checkpoint $repo/exp/pretrained.pt \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
 $repo/test_wavs/1089-134686-0001.wav \
 $repo/test_wavs/1221-135766-0001.wav \
 $repo/test_wavs/1221-135766-0002.wav

View File

@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
 pushd $repo/exp
 git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
 git lfs pull --include "exp/jit_script.pt"
 git lfs pull --include "exp/pretrained.pt"
 ln -s pretrained.pt epoch-99.pt
@@ -33,7 +34,7 @@ log "Export to torchscript model"
 ./zipformer/export.py \
 --exp-dir $repo/exp \
 --use-averaged-model false \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
 --epoch 99 \
 --avg 1 \
 --jit 1
@@ -43,7 +44,7 @@ ls -lh $repo/exp/*.pt
 log "Decode with models exported by torch.jit.script()"
 ./zipformer/jit_pretrained.py \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
 --nn-model-filename $repo/exp/jit_script.pt \
 $repo/test_wavs/1089-134686-0001.wav \
 $repo/test_wavs/1221-135766-0001.wav \
@@ -56,7 +57,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
 --method $method \
 --beam-size 4 \
 --checkpoint $repo/exp/pretrained.pt \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
 $repo/test_wavs/1089-134686-0001.wav \
 $repo/test_wavs/1221-135766-0001.wav \
 $repo/test_wavs/1221-135766-0002.wav

View File

@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
 pushd $repo/exp
 git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
 git lfs pull --include "data/lang_bpe_500/HLG.pt"
 git lfs pull --include "data/lang_bpe_500/L.pt"
 git lfs pull --include "data/lang_bpe_500/LG.pt"
@@ -40,7 +41,7 @@ log "Export to torchscript model"
 --use-transducer 1 \
 --use-ctc 1 \
 --use-averaged-model false \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
 --epoch 99 \
 --avg 1 \
 --jit 1

View File

@@ -24,7 +24,7 @@ You can generate the checkpoint with the following command:
 ./zipformer/export.py \
 --exp-dir ./zipformer/exp \
 --use-ctc 1 \
---bpe-model data/lang_bpe_500/bpe.model \
+--tokens data/lang_bpe_500/tokens.txt \
 --epoch 30 \
 --avg 9 \
 --jit 1
@@ -35,7 +35,7 @@ You can generate the checkpoint with the following command:
 --exp-dir ./zipformer/exp \
 --use-ctc 1 \
 --causal 1 \
---bpe-model data/lang_bpe_500/bpe.model \
+--tokens data/lang_bpe_500/tokens.txt \
 --epoch 30 \
 --avg 9 \
 --jit 1
@@ -45,7 +45,7 @@ Usage of this script:
 (1) ctc-decoding
 ./zipformer/jit_pretrained_ctc.py \
 --model-filename ./zipformer/exp/jit_script.pt \
---bpe-model data/lang_bpe_500/bpe.model \
+--tokens data/lang_bpe_500/tokens.txt \
 --method ctc-decoding \
 --sample-rate 16000 \
 /path/to/foo.wav \
@@ -91,10 +91,10 @@ from typing import List
 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from ctc_decode import get_decoding_params
+from export import num_tokens
 from torch.nn.utils.rnn import pad_sequence
 from train import get_params
@@ -136,9 +136,9 @@ def get_parser():
 )
 parser.add_argument(
-"--bpe-model",
+"--tokens",
 type=str,
-help="""Path to bpe.model.
+help="""Path to tokens.txt.
 Used only when method is ctc-decoding.
 """,
 )
@@ -149,8 +149,8 @@ def get_parser():
 default="1best",
 help="""Decoding method.
 Possible values are:
-(0) ctc-decoding - Use CTC decoding. It uses a sentence
-piece model, i.e., lang_dir/bpe.model, to convert
+(0) ctc-decoding - Use CTC decoding. It uses a token table,
+i.e., lang_dir/tokens.txt, to convert
 word pieces to words. It needs neither a lexicon
 nor an n-gram LM.
 (1) 1best - Use the best path as decoding output. Only
@@ -263,10 +263,8 @@ def main():
 params.update(get_decoding_params())
 params.update(vars(args))
-sp = spm.SentencePieceProcessor()
-sp.load(params.bpe_model)
-params.vocab_size = sp.get_piece_size()
+token_table = k2.SymbolTable.from_file(params.tokens)
+params.vocab_size = num_tokens(token_table)
 logging.info(f"{params}")
@@ -340,8 +338,7 @@ def main():
 lattice=lattice, use_double_scores=params.use_double_scores
 )
 token_ids = get_texts(best_path)
-hyps = sp.decode(token_ids)
-hyps = [s.split() for s in hyps]
+hyps = [[token_table[i] for i in ids] for ids in token_ids]
 elif params.method in [
 "1best",
 "nbest-rescoring",
@@ -415,6 +412,7 @@ def main():
 s = "\n"
 for filename, hyp in zip(params.sound_files, hyps):
 words = " ".join(hyp)
+words = words.replace("▁", " ").strip()
 s += f"{filename}:\n{words}\n\n"
 logging.info(s)
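
Taken together, the hunks above remove the SentencePiece dependency from this file: the vocabulary now comes from tokens.txt via k2.SymbolTable, and recognized token IDs are turned back into words by a table lookup plus a "▁" to space substitution. A minimal standalone sketch of that decoding path, assuming a tokens.txt in the usual "symbol id" per-line format (the helper name token_ids_to_texts and the default path are illustrative, not part of the patch):

    import k2

    def token_ids_to_texts(token_ids, tokens_file="data/lang_bpe_500/tokens.txt"):
        # tokens.txt maps one symbol per line to an integer id, e.g. "▁HELLO 123".
        token_table = k2.SymbolTable.from_file(tokens_file)

        # Same lookup as the patched code: integer id -> BPE piece.
        hyps = [[token_table[i] for i in ids] for ids in token_ids]

        texts = []
        for hyp in hyps:
            words = " ".join(hyp)
            # BPE pieces mark word starts with "▁"; turn them into spaces.
            # ["▁HELLO", "▁WORLD"] -> "HELLO  WORLD"; the doubled spaces are
            # harmless once the hypothesis is split on whitespace for scoring.
            words = words.replace("▁", " ").strip()
            texts.append(words)
        return texts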

View File

@@ -24,7 +24,7 @@ You can generate the checkpoint with the following command:
 ./zipformer/export.py \
 --exp-dir ./zipformer/exp \
 --use-ctc 1 \
---bpe-model data/lang_bpe_500/bpe.model \
+--tokens data/lang_bpe_500/tokens.txt \
 --epoch 30 \
 --avg 9
@@ -34,7 +34,7 @@ You can generate the checkpoint with the following command:
 --exp-dir ./zipformer/exp \
 --use-ctc 1 \
 --causal 1 \
---bpe-model data/lang_bpe_500/bpe.model \
+--tokens data/lang_bpe_500/tokens.txt \
 --epoch 30 \
 --avg 9
@@ -43,7 +43,7 @@ Usage of this script:
 (1) ctc-decoding
 ./zipformer/pretrained_ctc.py \
 --checkpoint ./zipformer/exp/pretrained.pt \
---bpe-model data/lang_bpe_500/bpe.model \
+--tokens data/lang_bpe_500/tokens.txt \
 --method ctc-decoding \
 --sample-rate 16000 \
 /path/to/foo.wav \
@@ -90,12 +90,12 @@ from typing import List
 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from ctc_decode import get_decoding_params
+from export import num_tokens
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
 from icefall.decode import (
 get_lattice,
@@ -144,9 +144,9 @@ def get_parser():
 )
 parser.add_argument(
-"--bpe-model",
+"--tokens",
 type=str,
-help="""Path to bpe.model.
+help="""Path to tokens.txt.
 Used only when method is ctc-decoding.
 """,
 )
@@ -157,8 +157,8 @@ def get_parser():
 default="1best",
 help="""Decoding method.
 Possible values are:
-(0) ctc-decoding - Use CTC decoding. It uses a sentence
-piece model, i.e., lang_dir/bpe.model, to convert
+(0) ctc-decoding - Use CTC decoding. It uses a token table,
+i.e., lang_dir/tokens.txt, to convert
 word pieces to words. It needs neither a lexicon
 nor an n-gram LM.
 (1) 1best - Use the best path as decoding output. Only
@@ -273,11 +273,10 @@ def main():
 params.update(get_decoding_params())
 params.update(vars(args))
-sp = spm.SentencePieceProcessor()
-sp.load(params.bpe_model)
-params.vocab_size = sp.get_piece_size()
-params.blank_id = 0
+token_table = k2.SymbolTable.from_file(params.tokens)
+params.vocab_size = num_tokens(token_table)
+params.blank_id = token_table["<blk>"]
+assert params.blank_id == 0
 logging.info(f"{params}")
@@ -358,8 +357,7 @@ def main():
 lattice=lattice, use_double_scores=params.use_double_scores
 )
 token_ids = get_texts(best_path)
-hyps = sp.decode(token_ids)
-hyps = [s.split() for s in hyps]
+hyps = [[token_table[i] for i in ids] for ids in token_ids]
 elif params.method in [
 "1best",
 "nbest-rescoring",
@@ -433,6 +431,7 @@ def main():
 s = "\n"
 for filename, hyp in zip(params.sound_files, hyps):
 words = " ".join(hyp)
+words = words.replace("▁", " ").strip()
 s += f"{filename}:\n{words}\n\n"
 logging.info(s)
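
pretrained_ctc.py gets the same treatment, and in addition reads the blank ID from the token table instead of hard-coding it, then asserts it is 0 as the rest of the CTC decoding code assumes. A hedged sketch of that check, with an illustrative tokens.txt path and excerpt (only the "<blk> 0" entry is confirmed by the assert in the patch):

    import k2

    # Illustrative head of data/lang_bpe_500/tokens.txt ("symbol id" per line):
    #   <blk> 0
    #   ...
    token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

    blank_id = token_table["<blk>"]  # look the blank symbol up by name
    assert blank_id == 0, f"CTC decoding expects <blk> at id 0, got {blank_id}"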