Mirror of https://github.com/k2-fsa/icefall.git
Fix zipformer-ctc
This commit is contained in:
parent 93dd3f5887
commit 655d170374
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
 git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
 git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
 git lfs pull --include "exp/pretrained.pt"
 ln -s pretrained.pt epoch-99.pt
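The CI scripts above now pull data/lang_bpe_500/tokens.txt alongside bpe.model. A minimal sketch of how such a token table is consumed downstream, assuming the usual layout of one "<symbol> <id>" pair per line (paths are illustrative):

import k2

# Assumed layout of tokens.txt: one "<symbol> <id>" pair per line, e.g. "<blk> 0".
token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

print(token_table["<blk>"])  # id of a symbol (the blank is expected to be 0)
print(token_table[1])        # symbol of an id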
@@ -33,7 +34,7 @@ log "Export to torchscript model"
 ./zipformer/export.py \
   --exp-dir $repo/exp \
   --use-averaged-model false \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --causal 1 \
   --chunk-size 16 \
   --left-context-frames 128 \
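Export now takes --tokens instead of --bpe-model. If only bpe.model is at hand, a tokens.txt in the assumed "<symbol> <id>" format can be derived from it; this is a rough sketch, not the repository's own prepare script:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # illustrative path

with open("data/lang_bpe_500/tokens.txt", "w", encoding="utf-8") as f:
    for i in range(sp.get_piece_size()):
        # One "<symbol> <id>" pair per line, the format k2.SymbolTable.from_file reads.
        f.write(f"{sp.id_to_piece(i)} {i}\n")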
@@ -46,7 +47,7 @@ ls -lh $repo/exp/*.pt
 log "Decode with models exported by torch.jit.script()"
 
 ./zipformer/jit_pretrained_streaming.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
   $repo/test_wavs/1089-134686-0001.wav
 
@@ -60,7 +61,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
     --method $method \
     --beam-size 4 \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
 git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
 git lfs pull --include "exp/jit_script.pt"
 git lfs pull --include "exp/pretrained.pt"
 ln -s pretrained.pt epoch-99.pt
@@ -33,7 +34,7 @@ log "Export to torchscript model"
 ./zipformer/export.py \
   --exp-dir $repo/exp \
   --use-averaged-model false \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --epoch 99 \
   --avg 1 \
   --jit 1
@@ -43,7 +44,7 @@ ls -lh $repo/exp/*.pt
 log "Decode with models exported by torch.jit.script()"
 
 ./zipformer/jit_pretrained.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --nn-model-filename $repo/exp/jit_script.pt \
   $repo/test_wavs/1089-134686-0001.wav \
   $repo/test_wavs/1221-135766-0001.wav \
@@ -56,7 +57,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
     --method $method \
     --beam-size 4 \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
 git lfs pull --include "data/lang_bpe_500/bpe.model"
+git lfs pull --include "data/lang_bpe_500/tokens.txt"
 git lfs pull --include "data/lang_bpe_500/HLG.pt"
 git lfs pull --include "data/lang_bpe_500/L.pt"
 git lfs pull --include "data/lang_bpe_500/LG.pt"
@@ -40,7 +41,7 @@ log "Export to torchscript model"
   --use-transducer 1 \
   --use-ctc 1 \
   --use-averaged-model false \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --epoch 99 \
   --avg 1 \
   --jit 1
@@ -24,7 +24,7 @@ You can generate the checkpoint with the following command:
 ./zipformer/export.py \
   --exp-dir ./zipformer/exp \
   --use-ctc 1 \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9 \
   --jit 1
@@ -35,7 +35,7 @@ You can generate the checkpoint with the following command:
   --exp-dir ./zipformer/exp \
   --use-ctc 1 \
   --causal 1 \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9 \
   --jit 1
@@ -45,7 +45,7 @@ Usage of this script:
 (1) ctc-decoding
 ./zipformer/jit_pretrained_ctc.py \
   --model-filename ./zipformer/exp/jit_script.pt \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method ctc-decoding \
   --sample-rate 16000 \
   /path/to/foo.wav \
@@ -91,10 +91,10 @@ from typing import List
 
 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from ctc_decode import get_decoding_params
+from export import num_tokens
 from torch.nn.utils.rnn import pad_sequence
 from train import get_params
 
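The script now builds its vocabulary from tokens.txt via k2.SymbolTable and a num_tokens helper imported from export.py. That helper is not part of this diff; a plausible sketch of what it might look like (an assumption, not the repository's exact code):

import k2

def num_tokens_sketch(token_table: k2.SymbolTable) -> int:
    # Count table entries while skipping disambiguation symbols such as "#0",
    # so the result can serve as the model's vocab_size.
    return sum(1 for sym in token_table.symbols if not sym.startswith("#"))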
@@ -136,9 +136,9 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.
+        help="""Path to tokens.txt.
         Used only when method is ctc-decoding.
         """,
     )
@@ -149,8 +149,8 @@ def get_parser():
         default="1best",
         help="""Decoding method.
         Possible values are:
-        (0) ctc-decoding - Use CTC decoding. It uses a sentence
-            piece model, i.e., lang_dir/bpe.model, to convert
+        (0) ctc-decoding - Use CTC decoding. It uses a token table,
+            i.e., lang_dir/tokens.txt, to convert
             word pieces to words. It needs neither a lexicon
             nor an n-gram LM.
         (1) 1best - Use the best path as decoding output. Only
@@ -263,10 +263,8 @@ def main():
     params.update(get_decoding_params())
     params.update(vars(args))
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    params.vocab_size = sp.get_piece_size()
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.vocab_size = num_tokens(token_table)
 
     logging.info(f"{params}")
 
@@ -340,8 +338,7 @@ def main():
             lattice=lattice, use_double_scores=params.use_double_scores
         )
         token_ids = get_texts(best_path)
-        hyps = sp.decode(token_ids)
-        hyps = [s.split() for s in hyps]
+        hyps = [[token_table[i] for i in ids] for ids in token_ids]
     elif params.method in [
         "1best",
         "nbest-rescoring",
@@ -415,6 +412,7 @@ def main():
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
         words = " ".join(hyp)
+        words = words.replace("▁", " ").strip()
         s += f"{filename}:\n{words}\n\n"
     logging.info(s)
 
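The ctc-decoding branch now looks token ids up in the token table and later strips SentencePiece's "▁" word marker, instead of calling sp.decode(). A toy walk-through of the new path, with made-up ids and an assumed tokens.txt path:

import k2

token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

token_ids = [[25, 8, 103]]  # stand-in for the output of get_texts(best_path)
hyps = [[token_table[i] for i in ids] for ids in token_ids]

for hyp in hyps:
    words = " ".join(hyp)                    # e.g. "▁HE LLO ▁WORLD"
    words = words.replace("▁", " ").strip()  # "▁" marks the start of a word
    print(words)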
@@ -24,7 +24,7 @@ You can generate the checkpoint with the following command:
 ./zipformer/export.py \
   --exp-dir ./zipformer/exp \
   --use-ctc 1 \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9
 
@@ -34,7 +34,7 @@ You can generate the checkpoint with the following command:
   --exp-dir ./zipformer/exp \
   --use-ctc 1 \
   --causal 1 \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9
 
@@ -43,7 +43,7 @@ Usage of this script:
 (1) ctc-decoding
 ./zipformer/pretrained_ctc.py \
   --checkpoint ./zipformer/exp/pretrained.pt \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method ctc-decoding \
   --sample-rate 16000 \
   /path/to/foo.wav \
@@ -90,12 +90,12 @@ from typing import List
 
 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from ctc_decode import get_decoding_params
+from export import num_tokens
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params
 
 from icefall.decode import (
     get_lattice,
@@ -144,9 +144,9 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.
+        help="""Path to tokens.txt.
         Used only when method is ctc-decoding.
         """,
     )
@@ -157,8 +157,8 @@ def get_parser():
         default="1best",
         help="""Decoding method.
         Possible values are:
-        (0) ctc-decoding - Use CTC decoding. It uses a sentence
-            piece model, i.e., lang_dir/bpe.model, to convert
+        (0) ctc-decoding - Use CTC decoding. It uses a token table,
+            i.e., lang_dir/tokens.txt, to convert
             word pieces to words. It needs neither a lexicon
             nor an n-gram LM.
         (1) 1best - Use the best path as decoding output. Only
@@ -273,11 +273,10 @@ def main():
     params.update(get_decoding_params())
     params.update(vars(args))
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    params.vocab_size = sp.get_piece_size()
-    params.blank_id = 0
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.vocab_size = num_tokens(token_table)
+    params.blank_id = token_table["<blk>"]
+    assert params.blank_id == 0
 
     logging.info(f"{params}")
 
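pretrained_ctc.py additionally reads the blank id from the token table instead of hard-coding 0. A minimal check along the same lines, assuming the blank symbol is spelled "<blk>" in tokens.txt:

import k2

token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")
blank_id = token_table["<blk>"]  # assumed blank symbol name
assert blank_id == 0, f"unexpected blank id: {blank_id}"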
@@ -358,8 +357,7 @@ def main():
             lattice=lattice, use_double_scores=params.use_double_scores
         )
         token_ids = get_texts(best_path)
-        hyps = sp.decode(token_ids)
-        hyps = [s.split() for s in hyps]
+        hyps = [[token_table[i] for i in ids] for ids in token_ids]
     elif params.method in [
         "1best",
         "nbest-rescoring",
@@ -433,6 +431,7 @@ def main():
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
         words = " ".join(hyp)
+        words = words.replace("▁", " ").strip()
         s += f"{filename}:\n{words}\n\n"
     logging.info(s)
 