updates for the pruned_transducer_stateless recipes

This commit is contained in:
jinzr 2023-07-24 23:25:40 +08:00
parent c03c011230
commit e0e8db3c91
38 changed files with 487 additions and 401 deletions
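
The recurring change below swaps the SentencePiece model for a plain tokens.txt file: the scripts now read a k2.SymbolTable, derive the <blk>/<unk> ids and the vocab size from it, and map decoded token ids back to words themselves instead of calling sp.decode(). The following is a minimal sketch of that pattern, assembled from the diffs below; the helper names load_token_info and ids_to_words are illustrative only, while k2.SymbolTable.from_file and icefall.utils.num_tokens are the calls the changed files actually use.

import k2
from icefall.utils import num_tokens


def load_token_info(tokens_file: str):
    # Replaces sp.piece_to_id("<blk>") / sp.get_piece_size() from the old code.
    token_table = k2.SymbolTable.from_file(tokens_file)
    blank_id = token_table["<blk>"]
    unk_id = token_table["<unk>"]
    vocab_size = num_tokens(token_table) + 1  # +1 for <blk>
    return token_table, blank_id, unk_id, vocab_size


def ids_to_words(token_ids, token_table) -> str:
    # Replaces sp.decode(hyp).split(); "▁" is the SentencePiece word-boundary mark.
    text = "".join(token_table[i] for i in token_ids)
    return text.replace("▁", " ").strip()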

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -36,7 +36,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -35,7 +35,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -30,14 +30,14 @@ popd
log "Export to torchscript model"
./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit-trace 1
@ -74,7 +74,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -32,7 +32,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav \

View File

@ -37,7 +37,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@ -74,7 +74,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -36,7 +36,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@ -72,7 +72,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -37,7 +37,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--decode-chunk-len 32 \
--epoch 99 \
--avg 1 \
@ -81,7 +81,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--decode-chunk-len 32 \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--decode-chunk-len 32 \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \

View File

@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()"
log "Export to torchscript model"
./pruned_transducer_stateless8/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model false \
--epoch 99 \
--avg 1 \
@ -65,7 +65,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -32,7 +32,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--simulate-streaming 1 \
--causal-convolution 1 \
$repo/test_wavs/1089-134686-0001.wav \
@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--simulate-streaming 1 \
--causal-convolution 1 \
$repo/test_wavs/1089-134686-0001.wav \

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless/export.py \
--exp-dir ./pruned_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -47,12 +47,12 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -87,10 +87,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -135,13 +135,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -79,7 +78,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -97,9 +96,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -237,13 +236,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.simulate_streaming:
assert (
@ -314,6 +314,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -325,8 +331,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -335,16 +341,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -365,12 +371,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -47,12 +47,12 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -98,10 +98,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -145,12 +145,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -79,7 +78,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -97,9 +96,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -238,13 +237,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.simulate_streaming:
assert (
@ -315,6 +315,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -326,8 +332,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -336,16 +342,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -366,12 +372,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless3/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 9999 \
--avg 1 \
--exp-dir $repo/exp/
@ -48,8 +48,8 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
@ -59,7 +59,7 @@ from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import setup_logger
from icefall.utils import num_tokens, setup_logger
def get_parser():
@ -105,10 +105,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -393,12 +393,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10 \
--jit 1
@ -44,7 +44,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`,
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10 \
--jit-trace 1
@ -56,7 +56,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -97,14 +97,14 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -150,10 +150,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@ -342,12 +342,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,7 +87,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -106,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -247,13 +246,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.simulate_streaming:
assert (
@ -324,6 +324,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -335,8 +341,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -345,16 +351,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -375,12 +381,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless4/export.py \
--exp-dir ./pruned_transducer_stateless4/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -48,7 +48,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -59,7 +59,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -116,10 +116,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -164,12 +164,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -58,8 +58,8 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
@ -74,7 +74,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -131,10 +131,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -489,12 +489,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -55,8 +55,8 @@ import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
@ -71,7 +71,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -128,10 +128,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -416,12 +416,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -48,7 +48,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -59,7 +59,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -116,10 +116,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -164,12 +164,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model:
assert params.causal_convolution

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -79,6 +78,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -95,9 +96,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -214,13 +215,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -275,6 +277,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -286,8 +294,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -296,16 +304,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -326,12 +334,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./pruned_transducer_stateless6/export.py \
--exp-dir ./pruned_transducer_stateless6/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -47,12 +47,12 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -98,10 +98,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -135,12 +135,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -86,7 +86,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -97,7 +97,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -154,10 +154,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -197,12 +197,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,6 +87,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -104,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -223,13 +224,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -284,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -295,8 +303,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -305,16 +313,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -335,12 +343,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,14 +22,14 @@ You can use the following command to get the exported models:
./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
Usage of this script:
(1) ctc-decoding
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--method ctc-decoding \
@ -38,7 +38,7 @@ Usage of this script:
/path/to/bar.wav
(2) 1best
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
@ -48,7 +48,7 @@ Usage of this script:
/path/to/bar.wav
(3) nbest-rescoring
./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
@ -60,7 +60,7 @@ Usage of this script:
(4) whole-lattice-rescoring
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13 \
--jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13
@ -86,7 +86,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -97,7 +97,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -154,10 +154,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -197,12 +197,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -28,7 +28,7 @@ Usage:
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13 \
--onnx 1
@ -48,7 +48,7 @@ Check `onnx_check.py` for how to use them.
(2) Export to ONNX format which can be used in Triton Server
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13 \
--onnx-triton 1
@ -86,9 +86,10 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -98,8 +99,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -156,10 +156,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -727,13 +727,15 @@ def main():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 13
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,6 +87,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -104,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -223,13 +224,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -284,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -295,8 +303,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -305,16 +313,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -335,12 +343,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,14 +22,14 @@ You can use the following command to get the exported models:
./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
Usage of this script:
(1) ctc-decoding
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--method ctc-decoding \
@ -38,7 +38,7 @@ Usage of this script:
/path/to/bar.wav
(2) 1best
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
@ -48,7 +48,7 @@ Usage of this script:
/path/to/bar.wav
(3) nbest-rescoring
./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
@ -60,7 +60,7 @@ Usage of this script:
(4) whole-lattice-rescoring
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \
./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \

View File

@ -66,6 +66,7 @@ import argparse
import logging
from pathlib import Path
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
@ -76,8 +77,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -123,10 +123,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="The tokens.txt file",
)
parser.add_argument(
@ -246,9 +246,14 @@ def main():
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
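
For context, a minimal sketch of the old and new ways this export script derives blank_id and vocab_size; the tokens.txt path below is an assumption, the rest mirrors the hunk above.

import k2
from icefall.utils import num_tokens

# New: everything comes from tokens.txt.
token_table = k2.SymbolTable.from_file("data/lang_char/tokens.txt")  # assumed path
blank_id = token_table["<blk>"]
unk_id = token_table["<unk>"]
vocab_size = num_tokens(token_table) + 1  # num_tokens() excludes <blk>, hence +1

# Old: the same values came from a Lexicon.
#   lexicon = Lexicon(params.lang_dir)
#   blank_id = 0
#   vocab_size = max(lexicon.tokens) + 1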

View File

@ -28,7 +28,7 @@ popd
2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
@ -64,7 +64,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
@ -75,7 +75,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -121,10 +121,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -244,12 +244,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)
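
To explain the "+1 for <blk>" above: num_tokens() from icefall.utils counts the usable tokens while skipping disambiguation symbols and the blank id. A rough, unofficial approximation, assuming disambiguation symbols are written "#0", "#1", ... and <blk> has id 0 in tokens.txt:

import re
import k2

def approx_num_tokens(token_table: k2.SymbolTable) -> int:
    # Skip disambiguation symbols such as "#0", "#1", ...
    disambig = re.compile(r"^#\d+$")
    ids = [token_table[s] for s in token_table.symbols if not disambig.match(s)]
    if 0 in ids:      # <blk> is excluded here; the caller adds it back with "+ 1"
        ids.remove(0)
    return len(ids)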

View File

@ -29,7 +29,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--tokens $repo/data/lang_char_bpe/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -60,6 +60,7 @@ import logging
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import onnx
import torch
import torch.nn as nn
@ -76,8 +77,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -134,10 +134,10 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
"--tokens",
type=str,
default="data/lang_char",
help="The lang dir",
default="data/lang_char/tokens.txt",
help="The tokens.txt file",
)
parser.add_argument(
@ -493,9 +493,14 @@ def main():
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -27,7 +27,7 @@ popd
2. Export the model to ONNX
./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -48,8 +48,8 @@ import logging
from pathlib import Path
from typing import Dict, List, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
@ -65,7 +65,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -122,10 +122,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -481,12 +481,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,7 +87,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -106,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -225,13 +224,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -286,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -297,8 +303,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -307,16 +313,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -337,12 +343,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -28,7 +28,7 @@ popd
2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
@ -64,7 +64,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
@ -75,7 +75,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -121,10 +121,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -244,12 +244,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless8/export.py \
--exp-dir ./pruned_transducer_stateless8/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless8/export.py \
--exp-dir ./pruned_transducer_stateless8/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -86,7 +86,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -98,7 +98,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -155,10 +155,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -198,12 +198,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless8/export.py \
--exp-dir ./pruned_transducer_stateless8/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search
./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search
./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search
./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search
./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -88,7 +87,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -106,9 +105,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -225,13 +224,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -286,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -297,8 +303,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -307,16 +313,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -337,12 +343,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")