updates for the pruned_transducer_stateless recipes

This commit is contained in:
jinzr 2023-07-24 23:25:40 +08:00
parent bbeca5ccd4
commit f13f0b990e
38 changed files with 487 additions and 401 deletions

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav

View File

@ -36,7 +36,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav
@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav

View File

@ -35,7 +35,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav
@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav

View File

@ -30,14 +30,14 @@ popd
log "Export to torchscript model" log "Export to torchscript model"
./pruned_transducer_stateless3/export.py \ ./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--jit 1 --jit 1
./pruned_transducer_stateless3/export.py \ ./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--jit-trace 1 --jit-trace 1
@ -74,7 +74,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav

View File

@ -32,7 +32,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--num-encoder-layers 18 \ --num-encoder-layers 18 \
--dim-feedforward 2048 \ --dim-feedforward 2048 \
--nhead 8 \ --nhead 8 \
@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav \ $repo/test_wavs/1221-135766-0002.wav \

View File

@ -37,7 +37,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7_ctc/export.py \ ./pruned_transducer_stateless7_ctc/export.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--use-averaged-model false \ --use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--jit 1 --jit 1
@ -74,7 +74,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav
@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav

View File

@ -36,7 +36,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7_ctc_bs/export.py \ ./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--use-averaged-model false \ --use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--jit 1 --jit 1
@ -72,7 +72,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav
@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav

View File

@ -37,7 +37,7 @@ log "Export to torchscript model"
./pruned_transducer_stateless7_streaming/export.py \ ./pruned_transducer_stateless7_streaming/export.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--use-averaged-model false \ --use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--decode-chunk-len 32 \ --decode-chunk-len 32 \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
@ -81,7 +81,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--decode-chunk-len 32 \ --decode-chunk-len 32 \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--decode-chunk-len 32 \ --decode-chunk-len 32 \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \

View File

@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()"
log "Export to torchscript model" log "Export to torchscript model"
./pruned_transducer_stateless8/export.py \ ./pruned_transducer_stateless8/export.py \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model false \ --use-averaged-model false \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
@ -65,7 +65,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav
@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav $repo/test_wavs/1221-135766-0002.wav

View File

@ -32,7 +32,7 @@ for sym in 1 2 3; do
--method greedy_search \ --method greedy_search \
--max-sym-per-frame $sym \ --max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--simulate-streaming 1 \ --simulate-streaming 1 \
--causal-convolution 1 \ --causal-convolution 1 \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \
@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \ --method $method \
--beam-size 4 \ --beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \ --checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--simulate-streaming 1 \ --simulate-streaming 1 \
--causal-convolution 1 \ --causal-convolution 1 \
$repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1089-134686-0001.wav \

View File

@ -22,7 +22,7 @@
Usage: Usage:
./pruned_transducer_stateless/export.py \ ./pruned_transducer_stateless/export.py \
--exp-dir ./pruned_transducer_stateless/exp \ --exp-dir ./pruned_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -47,12 +47,12 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -87,10 +87,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -135,13 +135,13 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py # Load id of the <blk> token and the vocab size, <blk> is
params.blank_id = sp.piece_to_id("<blk>") # defined in local/train_bpe_model.py
params.unk_id = sp.piece_to_id("<unk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model: if params.streaming_model:
assert params.causal_convolution assert params.causal_convolution

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless/pretrained.py \ ./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search (2) beam search
./pruned_transducer_stateless/pretrained.py \ ./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless/pretrained.py \ ./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless/pretrained.py \ ./pruned_transducer_stateless/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -79,7 +78,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -97,9 +96,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to tokens.txt.""",
) )
parser.add_argument( parser.add_argument(
@ -237,13 +236,14 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = token_table["<unk>"]
params.vocab_size = sp.get_piece_size() params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.simulate_streaming: if params.simulate_streaming:
assert ( assert (
@ -314,6 +314,12 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
    """Turn a sequence of token ids into a plain-text transcript.

    Every id is looked up in ``token_table`` (the symbol table loaded from
    ``params.tokens`` in the enclosing scope) and the resulting pieces are
    concatenated in order.

    Args:
      token_ids:
        Token ids of one hypothesis, as produced by the search routines.

    Returns:
      The decoded text, with word-boundary markers turned into single
      spaces and leading/trailing whitespace removed. An empty input
      yields an empty string.
    """
    # Join once instead of growing a string with ``+=`` inside a loop.
    pieces = "".join(token_table[i] for i in token_ids)

    # U+2581 (LOWER ONE EIGHTH BLOCK) is presumably the SentencePiece marker
    # that prefixes word-initial pieces in tokens.txt. It is written as an
    # escape so the character cannot be silently dropped by tooling:
    # replacing the *empty* string would instead insert a space between
    # every character of the transcript.
    return pieces.replace("\u2581", " ").strip()
if params.method == "fast_beam_search": if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -325,8 +331,8 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -335,16 +341,16 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
@ -365,12 +371,11 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append(token_ids_to_words(hyp))
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage: Usage:
./pruned_transducer_stateless2/export.py \ ./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -47,12 +47,12 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -98,10 +98,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -145,12 +145,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model: if params.streaming_model:
assert params.causal_convolution assert params.causal_convolution

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless2/pretrained.py \ ./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search (2) beam search
./pruned_transducer_stateless2/pretrained.py \ ./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless2/pretrained.py \ ./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless2/pretrained.py \ ./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -79,7 +78,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -97,9 +96,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to tokens.txt.""",
) )
parser.add_argument( parser.add_argument(
@ -238,13 +237,14 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = token_table["<unk>"]
params.vocab_size = sp.get_piece_size() params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.simulate_streaming: if params.simulate_streaming:
assert ( assert (
@ -315,6 +315,12 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
    """Turn a sequence of token ids into a plain-text transcript.

    Every id is looked up in ``token_table`` (the symbol table loaded from
    ``params.tokens`` in the enclosing scope) and the resulting pieces are
    concatenated in order.

    Args:
      token_ids:
        Token ids of one hypothesis, as produced by the search routines.

    Returns:
      The decoded text, with word-boundary markers turned into single
      spaces and leading/trailing whitespace removed. An empty input
      yields an empty string.
    """
    # Join once instead of growing a string with ``+=`` inside a loop.
    pieces = "".join(token_table[i] for i in token_ids)

    # U+2581 (LOWER ONE EIGHTH BLOCK) is presumably the SentencePiece marker
    # that prefixes word-initial pieces in tokens.txt. It is written as an
    # escape so the character cannot be silently dropped by tooling:
    # replacing the *empty* string would instead insert a space between
    # every character of the transcript.
    return pieces.replace("\u2581", " ").strip()
if params.method == "fast_beam_search": if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -326,8 +332,8 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -336,16 +342,16 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
@ -366,12 +372,11 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append(token_ids_to_words(hyp))
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX 2. Export the model to ONNX
./pruned_transducer_stateless3/export-onnx.py \ ./pruned_transducer_stateless3/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--exp-dir $repo/exp/ --exp-dir $repo/exp/
@ -48,8 +48,8 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple from typing import Dict, Tuple
import k2
import onnx import onnx
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from conformer import Conformer from conformer import Conformer
@ -59,7 +59,7 @@ from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import setup_logger from icefall.utils import num_tokens, setup_logger
def get_parser(): def get_parser():
@ -105,10 +105,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -393,12 +393,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless3/export.py \ ./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 \ --avg 10 \
--jit 1 --jit 1
@ -44,7 +44,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`,
./pruned_transducer_stateless3/export.py \ ./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 \ --avg 10 \
--jit-trace 1 --jit-trace 1
@ -56,7 +56,7 @@ It will generates 3 files: `encoder_jit_trace.pt`,
./pruned_transducer_stateless3/export.py \ ./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -97,14 +97,14 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -150,10 +150,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt",
) )
parser.add_argument( parser.add_argument(
@ -342,12 +342,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model: if params.streaming_model:
assert params.causal_convolution assert params.causal_convolution

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless3/export.py \ ./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \ --exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search (1) greedy search
./pruned_transducer_stateless3/pretrained.py \ ./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search (2) beam search
./pruned_transducer_stateless3/pretrained.py \ ./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless3/pretrained.py \ ./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless3/pretrained.py \ ./pruned_transducer_stateless3/pretrained.py \
--checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -88,7 +87,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -106,9 +105,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to tokens.txt.""",
) )
parser.add_argument( parser.add_argument(
@ -247,13 +246,14 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = token_table["<unk>"]
params.vocab_size = sp.get_piece_size() params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.simulate_streaming: if params.simulate_streaming:
assert ( assert (
@ -324,6 +324,12 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search": if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -335,8 +341,8 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -345,16 +351,16 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
@ -375,12 +381,11 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append(token_ids_to_words(hyp))
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage: Usage:
./pruned_transducer_stateless4/export.py \ ./pruned_transducer_stateless4/export.py \
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./pruned_transducer_stateless4/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -48,7 +48,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -59,7 +59,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -116,10 +116,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -164,12 +164,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model: if params.streaming_model:
assert params.causal_convolution assert params.causal_convolution

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX 2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx-streaming.py \ ./pruned_transducer_stateless5/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--use-averaged-model 0 \ --use-averaged-model 0 \
@ -58,8 +58,8 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple from typing import Dict, Tuple
import k2
import onnx import onnx
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from conformer import Conformer from conformer import Conformer
@ -74,7 +74,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import setup_logger, str2bool from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser(): def get_parser():
@ -131,10 +131,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -489,12 +489,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX 2. Export the model to ONNX
./pruned_transducer_stateless5/export-onnx.py \ ./pruned_transducer_stateless5/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--use-averaged-model 0 \ --use-averaged-model 0 \
@ -55,8 +55,8 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple from typing import Dict, Tuple
import k2
import onnx import onnx
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from conformer import Conformer from conformer import Conformer
@ -71,7 +71,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import setup_logger, str2bool from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser(): def get_parser():
@ -128,10 +128,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -416,12 +416,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -22,7 +22,7 @@
Usage: Usage:
./pruned_transducer_stateless5/export.py \ ./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -48,7 +48,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -59,7 +59,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -116,10 +116,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -164,12 +164,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
if params.streaming_model: if params.streaming_model:
assert params.causal_convolution assert params.causal_convolution

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search (2) beam search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -79,6 +78,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -95,9 +96,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to tokens.txt.""",
) )
parser.add_argument( parser.add_argument(
@ -214,13 +215,14 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = token_table["<unk>"]
params.vocab_size = sp.get_piece_size() params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}") logging.info(f"{params}")
@ -275,6 +277,12 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search": if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -286,8 +294,8 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -296,16 +304,16 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
@ -326,12 +334,11 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append(token_ids_to_words(hyp))
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage: Usage:
./pruned_transducer_stateless6/export.py \ ./pruned_transducer_stateless6/export.py \
--exp-dir ./pruned_transducer_stateless6/exp \ --exp-dir ./pruned_transducer_stateless6/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -47,12 +47,12 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -98,10 +98,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -135,12 +135,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7_ctc/export.py \ ./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 9 \ --avg 9 \
--jit 1 --jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7_ctc/export.py \ ./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -86,7 +86,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -97,7 +97,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -154,10 +154,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -197,12 +197,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_ctc/export.py \ ./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search (1) greedy search
./pruned_transducer_stateless7_ctc/pretrained.py \ ./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search (2) beam search
./pruned_transducer_stateless7_ctc/pretrained.py \ ./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless7_ctc/pretrained.py \ ./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless7_ctc/pretrained.py \ ./pruned_transducer_stateless7_ctc/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -88,6 +87,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -104,9 +105,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to tokens.txt.""",
) )
parser.add_argument( parser.add_argument(
@ -223,13 +224,14 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = token_table["<unk>"]
params.vocab_size = sp.get_piece_size() params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}") logging.info(f"{params}")
@ -284,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search": if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -295,8 +303,8 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -305,16 +313,16 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
@ -335,12 +343,11 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append(token_ids_to_words(hyp))
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")

View File

@ -22,14 +22,14 @@ You can use the following command to get the exported models:
./pruned_transducer_stateless7_ctc/export.py \ ./pruned_transducer_stateless7_ctc/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
Usage of this script: Usage of this script:
(1) ctc-decoding (1) ctc-decoding
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--method ctc-decoding \ --method ctc-decoding \
@ -38,7 +38,7 @@ Usage of this script:
/path/to/bar.wav /path/to/bar.wav
(2) 1best (2) 1best
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \ --HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \ --words-file data/lang_bpe_500/words.txt \
@ -48,7 +48,7 @@ Usage of this script:
/path/to/bar.wav /path/to/bar.wav
(3) nbest-rescoring (3) nbest-rescoring
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \ --HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \ --words-file data/lang_bpe_500/words.txt \
@ -60,7 +60,7 @@ Usage of this script:
(4) whole-lattice-rescoring (4) whole-lattice-rescoring
./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \ --HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \ --words-file data/lang_bpe_500/words.txt \

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless7_ctc_bs/export.py \ ./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 13 \ --avg 13 \
--jit 1 --jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless7_ctc_bs/export.py \ ./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 13 --avg 13
@ -86,7 +86,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -97,7 +97,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -154,10 +154,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -197,12 +197,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -28,7 +28,7 @@ Usage:
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 13 \ --avg 13 \
--onnx 1 --onnx 1
@ -48,7 +48,7 @@ Check `onnx_check.py` for how to use them.
(2) Export to ONNX format which can be used in Triton Server (2) Export to ONNX format which can be used in Triton Server
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 13 \ --avg 13 \
--onnx-triton 1 --onnx-triton 1
@ -86,9 +86,10 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -98,8 +99,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
def get_parser(): def get_parser():
@ -156,10 +156,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -728,12 +728,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_ctc_bs/export.py \ ./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 13 --avg 13
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search (1) greedy search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \ ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search (2) beam search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \ ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \ ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless7_ctc_bs/pretrained.py \ ./pruned_transducer_stateless7_ctc_bs/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -88,6 +87,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -104,9 +105,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to tokens.txt.""",
) )
parser.add_argument( parser.add_argument(
@ -223,13 +224,14 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = token_table["<unk>"]
params.vocab_size = sp.get_piece_size() params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}") logging.info(f"{params}")
@ -284,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
    """Map token IDs to their symbols and return the resulting sentence.

    Symbols are fetched from ``token_table`` (a variable of the enclosing
    scope), glued together, and the U+2581 word-boundary marker found in
    the BPE pieces is converted to an ordinary space. Leading and trailing
    whitespace is removed from the result.

    Args:
      token_ids: Token IDs of a single hypothesis.

    Returns:
      The decoded text; an empty string when ``token_ids`` is empty.
    """
    # A single join avoids repeated string concatenation in a loop.
    symbols = [token_table[i] for i in token_ids]
    merged = "".join(symbols)
    # Use the escape for the marker: with an empty search string,
    # str.replace would put a space between every character of the output.
    return merged.replace("\u2581", " ").strip()
if params.method == "fast_beam_search": if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -295,8 +303,8 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -305,16 +313,16 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
@ -335,12 +343,11 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append(token_ids_to_words(hyp))
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")

View File

@ -22,14 +22,14 @@ You can use the following command to get the exported models:
./pruned_transducer_stateless7_ctc_bs/export.py \ ./pruned_transducer_stateless7_ctc_bs/export.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
Usage of this script: Usage of this script:
(1) ctc-decoding (1) ctc-decoding
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--method ctc-decoding \ --method ctc-decoding \
@ -38,7 +38,7 @@ Usage of this script:
/path/to/bar.wav /path/to/bar.wav
(2) 1best (2) 1best
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \ --HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \ --words-file data/lang_bpe_500/words.txt \
@ -48,7 +48,7 @@ Usage of this script:
/path/to/bar.wav /path/to/bar.wav
(3) nbest-rescoring (3) nbest-rescoring
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \ --HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \ --words-file data/lang_bpe_500/words.txt \
@ -60,7 +60,7 @@ Usage of this script:
(4) whole-lattice-rescoring (4) whole-lattice-rescoring
./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ ./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \
--checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \ --HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \ --words-file data/lang_bpe_500/words.txt \

View File

@ -66,6 +66,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model from train2 import add_model_arguments, get_params, get_transducer_model
@ -76,8 +77,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon from icefall.utils import num_tokens, setup_logger, str2bool
from icefall.utils import setup_logger, str2bool
def get_parser(): def get_parser():
@ -123,10 +123,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lang-dir", "--tokens",
type=str, type=str,
default="data/lang_char", default="data/lang_char/tokens.txt",
help="The lang dir", help="The tokens.txt file",
) )
parser.add_argument( parser.add_argument(
@ -246,9 +246,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir) # Load tokens.txt here
params.blank_id = 0 token_table = k2.SymbolTable.from_file(params.tokens)
params.vocab_size = max(lexicon.tokens) + 1
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -28,7 +28,7 @@ popd
2. Export to ncnn 2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
@ -64,7 +64,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model from train2 import add_model_arguments, get_params, get_transducer_model
@ -75,7 +75,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import setup_logger, str2bool from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser(): def get_parser():
@ -121,10 +121,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -244,12 +244,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -29,7 +29,7 @@ popd
2. Export the model to ONNX 2. Export the model to ONNX
./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \
--lang-dir $repo/data/lang_char_bpe \ --tokens $repo/data/lang_char_bpe/tokens.txt \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
@ -60,6 +60,7 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import k2
import onnx import onnx
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -76,8 +77,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon from icefall.utils import num_tokens, setup_logger, str2bool
from icefall.utils import setup_logger, str2bool
def get_parser(): def get_parser():
@ -134,10 +134,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lang-dir", "--tokens",
type=str, type=str,
default="data/lang_char", default="data/lang_char/tokens.txt",
help="The lang dir", help="The tokens.txt file",
) )
parser.add_argument( parser.add_argument(
@ -493,9 +493,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir) # Load tokens.txt here
params.blank_id = 0 token_table = k2.SymbolTable.from_file(params.tokens)
params.vocab_size = max(lexicon.tokens) + 1
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -27,7 +27,7 @@ popd
2. Export the model to ONNX 2. Export the model to ONNX
./pruned_transducer_stateless7_streaming/export-onnx.py \ ./pruned_transducer_stateless7_streaming/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
@ -48,8 +48,8 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import k2
import onnx import onnx
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from decoder import Decoder from decoder import Decoder
@ -65,7 +65,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import setup_logger, str2bool from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser(): def get_parser():
@ -122,10 +122,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -481,12 +481,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_streaming/export.py \ ./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search (1) greedy search
./pruned_transducer_stateless7_streaming/pretrained.py \ ./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search (2) beam search
./pruned_transducer_stateless7_streaming/pretrained.py \ ./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless7_streaming/pretrained.py \ ./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless7_streaming/pretrained.py \ ./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -88,7 +87,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -106,9 +105,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to tokens.txt.""",
) )
parser.add_argument( parser.add_argument(
@ -225,13 +224,14 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = token_table["<unk>"]
params.vocab_size = sp.get_piece_size() params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}") logging.info(f"{params}")
@ -286,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
    """Decode one hypothesis from token IDs to readable text.

    The IDs are resolved through ``token_table`` (defined in the enclosing
    scope). The concatenated pieces use U+2581 to mark word starts; that
    marker is replaced by a space and the result is stripped.

    Args:
      token_ids: Token IDs of a single hypothesis.

    Returns:
      The decoded text; an empty string when ``token_ids`` is empty.
    """
    word_boundary = "\u2581"  # escaped so the character cannot get lost
    # NOTE: the search string must not be empty, otherwise str.replace
    # inserts the replacement between every character of the text.
    joined = "".join(token_table[i] for i in token_ids)
    return joined.replace(word_boundary, " ").strip()
if params.method == "fast_beam_search": if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -297,8 +303,8 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -307,16 +313,16 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
@ -337,12 +343,11 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append(token_ids_to_words(hyp))
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")

View File

@ -28,7 +28,7 @@ popd
2. Export to ncnn 2. Export to ncnn
./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \ --tokens $repo/data/lang_bpe_500/tokens.txt \
--exp-dir $repo/exp \ --exp-dir $repo/exp \
--use-averaged-model 0 \ --use-averaged-model 0 \
--epoch 99 \ --epoch 99 \
@ -64,7 +64,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model from train2 import add_model_arguments, get_params, get_transducer_model
@ -75,7 +75,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import setup_logger, str2bool from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser(): def get_parser():
@ -121,10 +121,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -244,12 +244,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -26,7 +26,7 @@ Usage:
./pruned_transducer_stateless8/export.py \ ./pruned_transducer_stateless8/export.py \
--exp-dir ./pruned_transducer_stateless8/exp \ --exp-dir ./pruned_transducer_stateless8/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \ --epoch 30 \
--avg 9 \ --avg 9 \
--jit 1 --jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./pruned_transducer_stateless8/export.py \ ./pruned_transducer_stateless8/export.py \
--exp-dir ./pruned_transducer_stateless8/exp \ --exp-dir ./pruned_transducer_stateless8/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -86,7 +86,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
@ -98,7 +98,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -155,10 +155,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_bpe_500/tokens.txt",
help="Path to the BPE model", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -198,12 +198,14 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.vocab_size = sp.get_piece_size() params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params) logging.info(params)

View File

@ -20,7 +20,7 @@ You can generate the checkpoint with the following command:
./pruned_transducer_stateless8/export.py \ ./pruned_transducer_stateless8/export.py \
--exp-dir ./pruned_transducer_stateless8/exp \ --exp-dir ./pruned_transducer_stateless8/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -29,7 +29,7 @@ Usage of this script:
(1) greedy search (1) greedy search
./pruned_transducer_stateless8/pretrained.py \ ./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -37,7 +37,7 @@ Usage of this script:
(2) beam search (2) beam search
./pruned_transducer_stateless8/pretrained.py \ ./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -46,7 +46,7 @@ Usage of this script:
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless8/pretrained.py \ ./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -55,7 +55,7 @@ Usage of this script:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless8/pretrained.py \ ./pruned_transducer_stateless8/pretrained.py \
--checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -75,7 +75,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -88,7 +87,7 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import str2bool from icefall.utils import num_tokens, str2bool
def get_parser(): def get_parser():
@ -106,9 +105,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to tokens.txt.""",
) )
parser.add_argument( parser.add_argument(
@ -225,13 +224,14 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() # Load tokens.txt here
sp.load(params.bpe_model) token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = token_table["<blk>"]
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = token_table["<unk>"]
params.vocab_size = sp.get_piece_size() params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}") logging.info(f"{params}")
@ -286,6 +286,12 @@ def main():
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
    """Convert a sequence of token ids into a space-separated transcript.

    Each id is looked up in ``token_table`` (taken from the enclosing
    scope) and the resulting pieces are concatenated.  The word-boundary
    marker U+2581 is then turned into a regular space and surrounding
    whitespace is stripped.

    Args:
      token_ids: Token ids of one hypothesis.

    Returns:
      The decoded text; an empty string for an empty hypothesis.
    """
    # Join once instead of growing a string with ``+=`` in a loop.
    pieces = "".join(token_table[i] for i in token_ids)
    # Written as an escape on purpose: the literal U+2581 character is easy
    # to lose in copy/paste, and replacing the empty string "" would insert
    # a space between every single character.
    # NOTE(review): assumes the pieces come from a SentencePiece/BPE
    # tokens.txt where U+2581 marks the start of a word.
    return pieces.replace("\u2581", " ").strip()
if params.method == "fast_beam_search": if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search_one_best(
@ -297,8 +303,8 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -307,16 +313,16 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in hyp_tokens:
hyps.append(hyp.split()) hyps.append(token_ids_to_words(hyp))
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
@ -337,12 +343,11 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append(token_ids_to_words(hyp))
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")