From f13f0b990e945c9f28332da646cbbe8c06435876 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 24 Jul 2023 23:25:40 +0800 Subject: [PATCH] updates for the `pruned_transducer_stateless` recipes --- ...-pruned-transducer-stateless-2022-03-12.sh | 4 +- ...pruned-transducer-stateless2-2022-04-29.sh | 4 +- ...pruned-transducer-stateless3-2022-04-29.sh | 4 +- ...pruned-transducer-stateless3-2022-05-13.sh | 8 +-- ...pruned-transducer-stateless5-2022-05-13.sh | 4 +- ...ed-transducer-stateless7-ctc-2022-12-01.sh | 6 +-- ...transducer-stateless7-ctc-bs-2022-12-15.sh | 6 +-- ...nsducer-stateless7-streaming-2022-12-29.sh | 6 +-- ...pruned-transducer-stateless8-2022-11-14.sh | 6 +-- ...pruned-transducer-stateless2-2022-06-26.sh | 4 +- .../ASR/pruned_transducer_stateless/export.py | 24 ++++----- .../pruned_transducer_stateless/pretrained.py | 49 ++++++++++-------- .../pruned_transducer_stateless2/export.py | 22 ++++---- .../pretrained.py | 49 ++++++++++-------- .../export-onnx.py | 22 ++++---- .../pruned_transducer_stateless3/export.py | 26 +++++----- .../pretrained.py | 51 ++++++++++--------- .../pruned_transducer_stateless4/export.py | 22 ++++---- .../export-onnx-streaming.py | 22 ++++---- .../export-onnx.py | 22 ++++---- .../pruned_transducer_stateless5/export.py | 22 ++++---- .../pretrained.py | 49 ++++++++++-------- .../pruned_transducer_stateless6/export.py | 22 ++++---- .../export.py | 24 +++++---- .../pretrained.py | 51 +++++++++++-------- .../pretrained_ctc.py | 10 ++-- .../export.py | 24 +++++---- .../export_onnx.py | 28 +++++----- .../pretrained.py | 51 +++++++++++-------- .../pretrained_ctc.py | 10 ++-- .../export-for-ncnn-zh.py | 21 +++++--- .../export-for-ncnn.py | 22 ++++---- .../export-onnx-zh.py | 23 +++++---- .../export-onnx.py | 22 ++++---- .../pretrained.py | 51 ++++++++++--------- .../export-for-ncnn.py | 22 ++++---- .../pruned_transducer_stateless8/export.py | 24 +++++---- .../pretrained.py | 51 ++++++++++--------- 38 files changed, 487 insertions(+), 401 deletions(-) diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh index 6792c7088..412e3ad56 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index dbf678d72..243b669ed 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -36,7 +36,7 @@ for sym in 1 2 3; do --method greedy_search \ 
--max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index b6d477afe..2d0f80304 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -35,7 +35,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index efa4b53f0..3d5814c48 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -30,14 +30,14 @@ popd log "Export to torchscript model" ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit-trace 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index 511fe0c9e..3d2442d54 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ 
b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 18 \ --dim-feedforward 2048 \ --nhead 8 \ @@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh index 192438353..ba7139efb 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh index 761eb72e2..6fc477116 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh @@ -36,7 +36,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -72,7 +72,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git 
a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index e1e4e1f10..37b192a57 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ --epoch 99 \ --avg 1 \ @@ -81,7 +81,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ @@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh index 5d9485692..4f2bfac24 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh @@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()" log "Export to torchscript model" ./pruned_transducer_stateless8/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model false \ --epoch 99 \ --avg 1 \ @@ -65,7 +65,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index 77cd59506..5cbdad16d 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ @@ -47,7 +47,7 @@ for method in modified_beam_search 
beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index a19f9ab9a..4b20e3a2b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless/export.py \ --exp-dir ./pruned_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -87,10 +87,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,13 +135,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 2ed1725b4..02f9f1b03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method 
fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -237,13 +236,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -314,6 +314,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -325,8 +331,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -335,16 +341,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -365,12 +371,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 984caf5f2..e02afa892 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import 
add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -145,12 +145,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 013964720..029f55ba0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -238,13 +237,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = 
sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -315,6 +315,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -326,8 +332,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -336,16 +342,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -366,12 +372,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index ce297ad12..26dea7e11 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -28,7 +28,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer @@ -59,7 +59,7 @@ from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import setup_logger +from icefall.utils import num_tokens, setup_logger def get_parser(): @@ -105,10 +105,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -393,12 +393,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index f30c9df6a..925b15646 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit 1 @@ -44,7 +44,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit-trace 1 @@ -56,7 +56,7 @@ It will generates 3 files: `encoder_jit_trace.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -97,14 +97,14 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -150,10 +150,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -342,12 +342,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # 
Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 7c3dfc660..abda4e2d4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -324,6 +324,12 @@ def main(): msg += f" with beam 
size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -335,8 +341,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -345,16 +351,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -375,12 +381,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index 8f33f5b05..08d736f52 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless4/export.py \ --exp-dir ./pruned_transducer_stateless4/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py 
b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 7c6be8e27..549fb13c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless5/export-onnx-streaming.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -58,8 +58,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer @@ -74,7 +74,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -131,10 +131,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -489,12 +489,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index 5ef453fac..fff0fcdd5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -28,7 +28,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless5/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,8 +55,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer @@ -71,7 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -416,12 +416,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index 54f656859..e5223be26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 74a2210c3..304fa8693 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +78,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -214,13 +215,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +277,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +294,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +304,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in 
sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +334,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index 4d0d8326c..38f48b2ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless6/export.py \ --exp-dir ./pruned_transducer_stateless6/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,12 +135,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py index c1607699f..51e62d6a8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. 
./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 2f1b1a49f..78e0fa778 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 
+87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 5d460edb5..904c1deae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ 
@@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./bruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py index 05df8cfff..9f35cf63e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py index 630a7f735..795d23c21 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -28,7 +28,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir 
./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx 1 @@ -48,7 +48,7 @@ Check `onnx_check.py` for how to use them. (2) Export to ONNX format which can be used in Triton Server ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx-triton 1 @@ -86,9 +86,10 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn +from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -98,8 +99,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool -from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -156,10 +156,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -727,13 +727,15 @@ def main(): device = torch.device("cuda", 0) logging.info(f"device: {device}") + + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index ea0fe9164..5d240cf30 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens 
data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 +87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index 412631ba1..914107526 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ @@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./bruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py index e196f8b7d..07de57a86 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -66,6 +66,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,10 +123,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -246,9 +246,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- 
a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py index 142ecc341..8653126de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -29,7 +29,7 @@ popd 2. 
Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -60,6 +60,7 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -134,10 +134,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -493,9 +493,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index c6394019e..6f84d79b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -27,7 +27,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -65,7 +65,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -122,10 +122,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -481,12 +481,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index fb77fdd42..bc42e8d05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ 
b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif 
params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index d4a228b47..d9697680b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. 
./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -98,7 +98,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -155,10 +155,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +198,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 486d9d74e..64b38c9d5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import 
add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done")
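Across these recipes the patch swaps the SentencePiece loader for a plain tokens.txt symbol table. The following is a minimal sketch of that setup, assuming the standard icefall tokens.txt that contains <blk> and <unk> entries; the path and variable names are illustrative rather than taken from any one script.

import k2

from icefall.utils import num_tokens

# Hypothetical path; in the scripts it arrives through the new --tokens argument.
tokens_file = "data/lang_bpe_500/tokens.txt"

# tokens.txt maps each token string to its integer id.
token_table = k2.SymbolTable.from_file(tokens_file)

blank_id = token_table["<blk>"]           # id of the blank symbol
unk_id = token_table["<unk>"]             # id of the unknown symbol
vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

The +1 restores the blank symbol, which num_tokens() leaves out of its count, so vocab_size matches what the SentencePiece-based code previously reported.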
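On the decoding side, sp.decode(hyp_tokens) is replaced by a direct lookup in the same table. Below is a sketch of the helper the scripts define inline, written here as a free function with the table passed in so it can stand alone; the usage lines are hypothetical.

from typing import List

import k2


def token_ids_to_words(token_ids: List[int], token_table: k2.SymbolTable) -> str:
    # Concatenate the BPE pieces of one hypothesis, then turn the
    # SentencePiece word-boundary marker "▁" back into spaces.
    text = ""
    for i in token_ids:
        text += token_table[i]
    return text.replace("▁", " ").strip()


# Hypothetical usage; hyp would come from e.g. greedy_search_batch().
# token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")
# print(token_ids_to_words(hyp, token_table))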