From d6b28a11a70871a76b66ccf80667dd1d3ac1ab17 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 11 Aug 2023 23:57:00 +0800 Subject: [PATCH 1/7] Add export script for the yesno recipe. (#1212) --- .github/workflows/run-yesno-recipe.yml | 76 +++++++- egs/yesno/ASR/tdnn/decode.py | 1 - egs/yesno/ASR/tdnn/export.py | 118 ++++++++++++ egs/yesno/ASR/tdnn/export_onnx.py | 158 ++++++++++++++++ egs/yesno/ASR/tdnn/jit_pretrained.py | 199 ++++++++++++++++++++ egs/yesno/ASR/tdnn/onnx_pretrained.py | 241 +++++++++++++++++++++++++ egs/yesno/ASR/tdnn/pretrained.py | 37 +++- 7 files changed, 813 insertions(+), 17 deletions(-) create mode 100755 egs/yesno/ASR/tdnn/export.py create mode 100755 egs/yesno/ASR/tdnn/export_onnx.py create mode 100755 egs/yesno/ASR/tdnn/jit_pretrained.py create mode 100755 egs/yesno/ASR/tdnn/onnx_pretrained.py diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 8a2c94829..57f15fe87 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -44,11 +44,6 @@ jobs: with: fetch-depth: 0 - - name: Install graphviz - shell: bash - run: | - sudo apt-get -qq install graphviz - - name: Setup Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -70,6 +65,7 @@ jobs: pip install --no-binary protobuf protobuf==3.20.* pip install --no-deps --force-reinstall https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.3.dev20230508+cpu.torch1.13.1-cp38-cp38-linux_x86_64.whl + pip install kaldifeat==1.25.0.dev20230726+cpu.torch1.13.1 -f https://csukuangfj.github.io/kaldifeat/cpu.html - name: Run yesno recipe shell: bash @@ -78,9 +74,75 @@ jobs: export PYTHONPATH=$PWD:$PYTHONPATH echo $PYTHONPATH - cd egs/yesno/ASR ./prepare.sh python3 ./tdnn/train.py python3 ./tdnn/decode.py - # TODO: Check that the WER is less than some value + + - name: Test exporting to pretrained.pt + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 + + python3 ./tdnn/pretrained.py \ + --checkpoint ./tdnn/exp/pretrained.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + - name: Test exporting to torchscript + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export.py --epoch 14 --avg 2 --jit 1 + + python3 ./tdnn/jit_pretrained.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + - name: Test exporting to onnx + shell: bash + working-directory: ${{github.workspace}} + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + echo $PYTHONPATH + + cd egs/yesno/ASR + python3 ./tdnn/export_onnx.py --epoch 14 --avg 2 + + echo "Test float32 model" + python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + + echo "Test int8 model" + python3 ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + + - name: Show generated files + shell: bash + working-directory: ${{github.workspace}} + run: | + cd egs/yesno/ASR + ls -lh tdnn/exp diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index d5efb41df..f520607af 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -65,7 +65,6 @@ def get_params() -> AttributeDict: { "exp_dir": Path("tdnn/exp/"), "lang_dir": Path("data/lang_phone"), - "lm_dir": Path("data/lm"), "feature_dim": 23, "search_beam": 20, "output_beam": 8, diff --git a/egs/yesno/ASR/tdnn/export.py b/egs/yesno/ASR/tdnn/export.py new file mode 100755 index 000000000..c40cf8cd1 --- /dev/null +++ b/egs/yesno/ASR/tdnn/export.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +""" +This file is for exporting trained models to a checkpoint +or to a torchscript model. + +(1) Generate the checkpoint tdnn/exp/pretrained.pt + +./tdnn/export.py \ + --epoch 14 \ + --avg 2 + +See ./tdnn/pretrained.py for how to use the generated file. + +(2) Generate torchscript model tdnn/exp/cpu_jit.pt + +./tdnn/export.py \ + --epoch 14 \ + --avg 2 \ + --jit 1 + +See ./tdnn/jit_pretrained.py for how to use the generated file. +""" + +import argparse +import logging + +import torch +from model import Tdnn +from train import get_params + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=14, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=2, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + + model = Tdnn( + num_features=params.feature_dim, + num_classes=max_token_id + 1, # +1 for the blank symbol + ) + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/export_onnx.py b/egs/yesno/ASR/tdnn/export_onnx.py new file mode 100755 index 000000000..9b2a56d59 --- /dev/null +++ b/egs/yesno/ASR/tdnn/export_onnx.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 + +""" +This file is for exporting trained models to onnx. + +Usage: + + ./tdnn/export_onnx.py \ + --epoch 14 \ + --avg 2 + +The above command generates the following two files: + - ./exp/model-epoch-14-avg-2.onnx + - ./exp/model-epoch-14-avg-2.int8.onnx + +See ./tdnn/onnx_pretrained.py for how to use them. +""" + +import argparse +import logging +from typing import Dict + +import onnx +import torch +from model import Tdnn +from onnxruntime.quantization import QuantType, quantize_dynamic +from train import get_params + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=14, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=2, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + + model = Tdnn( + num_features=params.feature_dim, + num_classes=max_token_id + 1, # +1 for the blank symbol + ) + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + N = 1 + T = 100 + C = params.feature_dim + x = torch.rand(N, T, C) + + opset_version = 13 + onnx_filename = f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.onnx" + torch.onnx.export( + model, + x, + onnx_filename, + verbose=False, + opset_version=opset_version, + input_names=["x"], + output_names=["log_prob"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "log_prob": {0: "N", 1: "T"}, + }, + ) + + logging.info(f"Saved to {onnx_filename}") + meta_data = { + "model_type": "tdnn_lstm", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming tdnn for the yesno recipe", + "vocab_size": max_token_id + 1, + } + + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=onnx_filename, meta_data=meta_data) + + logging.info("Generate int8 quantization models") + onnx_filename_int8 = ( + f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.int8.onnx" + ) + + quantize_dynamic( + model_input=onnx_filename, + model_output=onnx_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + logging.info(f"Saved to {onnx_filename_int8}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/jit_pretrained.py b/egs/yesno/ASR/tdnn/jit_pretrained.py new file mode 100755 index 000000000..84390fca5 --- /dev/null +++ b/egs/yesno/ASR/tdnn/jit_pretrained.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use a torchscript model for decoding. + +Usage: + + ./tdnn/jit_pretrained.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +Note that to generate ./tdnn/exp/cpu_jit.pt, +you can use ./export.py --jit 1 +""" + +import argparse +import logging +from typing import List +import math + + +import k2 +import kaldifeat +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import get_lattice, one_best_decoding +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 23, + "num_classes": 4, # [, N, SIL, Y] + "sample_rate": 8000, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + nnet_output = model(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py new file mode 100755 index 000000000..626473b6e --- /dev/null +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 + +""" +This file shows how to use an ONNX model for decoding with onnxruntime. + +Usage: + +(1) Use a not quantized ONNX model, i.e., a float32 model + ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +(2) Use a quantized ONNX model, i.e., an int8 model + + ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +Note that to generate ./tdnn/exp/model-epoch-14-avg-2.onnx, +and ./tdnn/exp/model-epoch-14-avg-2.onnx, +you can use ./export_onnx.py --epoch 14 --avg 2 +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import onnxruntime as ort +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import get_lattice, one_best_decoding +from icefall.utils import AttributeDict, get_texts + + +class OnnxModel: + def __init__(self, nn_model: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + self.model = ort.InferenceSession( + nn_model, + sess_options=self.session_opts, + ) + + meta = self.model.get_modelmeta().custom_metadata_map + self.vocab_size = int(meta["vocab_size"]) + + def run( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C) + Returns: + Return a 3-D tensor log_prob of shape (N, T, C) + """ + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + }, + ) + return torch.from_numpy(out[0]) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./tdnn/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. ", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 23, + "sample_rate": 8000, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"device: {device}") + + logging.info(f"Loading onnx model {params.nn_model}") + model = OnnxModel(params.nn_model) + + logging.info(f"Loading HLG from {args.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + nnet_output = model.run(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index 65be77db1..987c49de6 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -15,6 +15,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This file shows how to use a checkpoint for decoding. + +Usage: + + ./tdnn/pretrained.py \ + --checkpoint ./tdnn/exp/pretrained.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +Note that to generate ./tdnn/exp/pretrained.pt, +you can use ./export.py +""" import argparse import logging @@ -43,7 +58,8 @@ def get_parser(): required=True, help="Path to the checkpoint. " "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", + "icefall.checkpoint.save_checkpoint(). " + "You can use ./tdnn/export.py to obtain it.", ) parser.add_argument( @@ -61,8 +77,7 @@ def get_parser(): nargs="+", help="The input sound file(s) to transcribe. " "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", + "For example, wav and flac are supported. ", ) return parser @@ -99,14 +114,19 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + # We use only the first channel - ans.append(wave[0]) + ans.append(wave[0].contiguous()) return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -159,8 +179,7 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) # Note: We don't use key padding mask for attention during decoding - with torch.no_grad(): - nnet_output = model(features) + nnet_output = model(features) batch_size = nnet_output.shape[0] supervision_segments = torch.tensor( From a81396b482c799b2ace2cefb79859be827b16f00 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sat, 12 Aug 2023 16:53:59 +0800 Subject: [PATCH 2/7] Use tokens.txt to replace bpe.model (#1162) --- ...n-librispeech-conformer-ctc3-2022-11-28.sh | 10 +- ...h-lstm-transducer-stateless2-2022-09-03.sh | 6 +- ...-pruned-transducer-stateless-2022-03-12.sh | 4 +- ...pruned-transducer-stateless2-2022-04-29.sh | 4 +- ...pruned-transducer-stateless3-2022-04-29.sh | 4 +- ...pruned-transducer-stateless3-2022-05-13.sh | 8 +- ...pruned-transducer-stateless5-2022-05-13.sh | 4 +- ...pruned-transducer-stateless7-2022-11-11.sh | 6 +- ...ed-transducer-stateless7-ctc-2022-12-01.sh | 6 +- ...transducer-stateless7-ctc-bs-2023-01-29.sh | 6 +- ...nsducer-stateless7-streaming-2022-12-29.sh | 6 +- ...pruned-transducer-stateless8-2022-11-14.sh | 6 +- ...pruned-transducer-stateless2-2022-06-26.sh | 4 +- ...speech-transducer-stateless2-2022-04-19.sh | 4 +- ...un-librispeech-zipformer-mmi-2022-12-08.sh | 4 +- .../scripts/run-pre-trained-conformer-ctc.sh | 4 +- ...d-transducer-stateless-librispeech-100h.sh | 4 +- ...d-transducer-stateless-librispeech-960h.sh | 4 +- .../run-pre-trained-transducer-stateless.sh | 4 +- .github/scripts/run-pre-trained-transducer.sh | 2 +- ...enetspeech-pruned-transducer-stateless2.sh | 36 +- .github/scripts/test-ncnn-export.sh | 12 +- .github/scripts/test-onnx-export.sh | 138 ++++++- .../pruned_transducer_stateless7/export.py | 322 +--------------- .../pretrained.py | 349 +----------------- egs/librispeech/ASR/conformer_ctc/export.py | 18 +- .../ASR/conformer_ctc/pretrained.py | 40 +- egs/librispeech/ASR/conformer_ctc2/export.py | 19 +- egs/librispeech/ASR/conformer_ctc3/export.py | 23 +- .../ASR/conformer_ctc3/pretrained.py | 42 ++- .../export.py | 22 +- .../export-for-ncnn.py | 22 +- .../export-onnx.py | 25 +- .../export.py | 22 +- .../onnx_pretrained.py | 2 +- .../ASR/lstm_transducer_stateless/export.py | 25 +- .../lstm_transducer_stateless/pretrained.py | 49 +-- .../export-for-ncnn.py | 23 +- .../export-onnx-zh.py | 2 +- .../lstm_transducer_stateless2/export-onnx.py | 25 +- .../ASR/lstm_transducer_stateless2/export.py | 25 +- .../lstm_transducer_stateless2/pretrained.py | 49 +-- .../ASR/lstm_transducer_stateless3/export.py | 25 +- .../lstm_transducer_stateless3/pretrained.py | 46 ++- .../pruned_stateless_emformer_rnnt2/export.py | 23 +- .../export-onnx.py | 2 +- .../ASR/pruned_transducer_stateless/export.py | 24 +- .../pruned_transducer_stateless/pretrained.py | 49 +-- .../pruned_transducer_stateless2/export.py | 22 +- .../pretrained.py | 49 +-- .../export-onnx.py | 24 +- .../pruned_transducer_stateless3/export.py | 26 +- .../pretrained.py | 51 +-- .../pruned_transducer_stateless4/export.py | 22 +- .../export-onnx-streaming.py | 26 +- .../export-onnx.py | 26 +- .../pruned_transducer_stateless5/export.py | 22 +- .../pretrained.py | 49 +-- .../pruned_transducer_stateless6/export.py | 22 +- .../export-onnx.py | 27 +- .../pruned_transducer_stateless7/export.py | 30 +- .../pretrained.py | 55 +-- .../export.py | 24 +- .../pretrained.py | 51 +-- .../pretrained_ctc.py | 10 +- .../export.py | 24 +- .../export_onnx.py | 26 +- .../pretrained.py | 51 +-- .../pretrained_ctc.py | 10 +- .../export-for-ncnn-zh.py | 21 +- .../export-for-ncnn.py | 22 +- .../export-onnx-zh.py | 25 +- .../export-onnx.py | 24 +- .../export.py | 20 +- .../pretrained.py | 51 +-- .../export-for-ncnn.py | 22 +- .../pruned_transducer_stateless8/export.py | 24 +- .../pretrained.py | 51 +-- egs/librispeech/ASR/transducer/export.py | 22 +- egs/librispeech/ASR/transducer/pretrained.py | 33 +- .../ASR/transducer_stateless/export.py | 22 +- .../ASR/transducer_stateless/pretrained.py | 36 +- .../ASR/transducer_stateless2/export.py | 22 +- .../ASR/transducer_stateless2/pretrained.py | 36 +- .../export.py | 22 +- .../pretrained.py | 36 +- .../ASR/zipformer/export-onnx-streaming.py | 4 +- egs/librispeech/ASR/zipformer/export-onnx.py | 4 +- egs/librispeech/ASR/zipformer/export.py | 25 +- .../ASR/zipformer/jit_pretrained_ctc.py | 18 +- egs/librispeech/ASR/zipformer/onnx_check.py | 1 - .../zipformer/onnx_pretrained-streaming.py | 3 +- .../ASR/zipformer/onnx_pretrained.py | 1 - .../ASR/zipformer/pretrained_ctc.py | 20 +- egs/librispeech/ASR/zipformer_mmi/export.py | 24 +- .../ASR/zipformer_mmi/pretrained.py | 47 +-- .../export-onnx.py | 2 +- .../pretrained.py | 2 +- icefall/utils.py | 20 + 99 files changed, 1243 insertions(+), 1623 deletions(-) mode change 100755 => 120000 egs/aishell/ASR/pruned_transducer_stateless7/export.py mode change 100644 => 120000 egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh index c68ccc954..f6fe8c9b2 100755 --- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh +++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh @@ -38,7 +38,7 @@ log "Decode with models exported by torch.jit.trace()" for m in ctc-decoding 1best; do ./conformer_ctc3/jit_pretrained.py \ --model-filename $repo/exp/jit_trace.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ --G $repo/data/lm/G_4_gram.pt \ @@ -53,7 +53,7 @@ log "Export to torchscript model" ./conformer_ctc3/export.py \ --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_bpe_500 \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --jit-trace 1 \ --epoch 99 \ --avg 1 \ @@ -80,9 +80,9 @@ done for m in ctc-decoding 1best; do ./conformer_ctc3/pretrained.py \ --checkpoint $repo/exp/pretrained.pt \ - --words-file $repo/data/lang_bpe_500/words.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --G $repo/data/lm/G_4_gram.pt \ --method $m \ --sample-rate 16000 \ @@ -93,7 +93,7 @@ done echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then mkdir -p conformer_ctc3/exp ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt ln -s $PWD/$repo/data/lang_bpe_500 data/ diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh index 4cd2c4bec..d547bdd45 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -31,7 +31,7 @@ log "Test exporting with torch.jit.trace()" ./lstm_transducer_stateless2/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,7 +55,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -68,7 +68,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh index 6792c7088..412e3ad56 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index dbf678d72..243b669ed 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -36,7 +36,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -49,7 +49,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh index b6d477afe..2d0f80304 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh @@ -35,7 +35,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -48,7 +48,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh index efa4b53f0..3d5814c48 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh @@ -30,14 +30,14 @@ popd log "Export to torchscript model" ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 ./pruned_transducer_stateless3/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit-trace 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index 511fe0c9e..3d2442d54 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 18 \ --dim-feedforward 2048 \ --nhead 8 \ @@ -51,7 +51,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh index 2bc179c86..961dde4f4 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh @@ -33,7 +33,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -56,7 +56,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -69,7 +69,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh index 192438353..ba7139efb 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -74,7 +74,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -87,7 +87,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh index 7d2853c17..1ecbc4798 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh @@ -36,7 +36,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -72,7 +72,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -85,7 +85,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh index e1e4e1f10..37b192a57 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ --epoch 99 \ --avg 1 \ @@ -81,7 +81,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ @@ -95,7 +95,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --decode-chunk-len 32 \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh index 5d9485692..4f2bfac24 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh @@ -41,7 +41,7 @@ log "Decode with models exported by torch.jit.script()" log "Export to torchscript model" ./pruned_transducer_stateless8/export.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model false \ --epoch 99 \ --avg 1 \ @@ -65,7 +65,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -78,7 +78,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh index 77cd59506..5cbdad16d 100755 --- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh +++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh @@ -32,7 +32,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ @@ -47,7 +47,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --simulate-streaming 1 \ --causal-convolution 1 \ $repo/test_wavs/1089-134686-0001.wav \ diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh index b4aca1b6b..ff77855a2 100755 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh index a58b8ec56..c59921055 100755 --- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh +++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./zipformer_mmi/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor --method $method \ --checkpoint $repo/exp/pretrained.pt \ --lang-dir $repo/data/lang_bpe_500 \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 125d1f3b1..a4959aa01 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -27,7 +27,7 @@ log "CTC decoding" --method ctc-decoding \ --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.flac \ $repo/test_wavs/1221-135766-0001.flac \ $repo/test_wavs/1221-135766-0002.flac @@ -38,7 +38,7 @@ log "HLG decoding" --method 1best \ --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --words-file $repo/data/lang_bpe_500/words.txt \ --HLG $repo/data/lang_bpe_500/HLG.pt \ $repo/test_wavs/1089-134686-0001.flac \ diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh index 89115e88d..7b686328d 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh index 85e2c89e6..a8eeeb514 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh index 41456f11b..2e2360435 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh index 1331c966c..b865f8d13 100755 --- a/.github/scripts/run-pre-trained-transducer.sh +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -27,7 +27,7 @@ log "Beam search decoding" --method beam_search \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh index 90097c752..a3a2d3080 100755 --- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh +++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh @@ -17,7 +17,6 @@ git lfs install git clone $repo_url repo=$(basename $repo_url) - log "Display test files" tree $repo/ ls -lh $repo/test_wavs/*.wav @@ -29,12 +28,11 @@ popd log "Test exporting to ONNX format" -./pruned_transducer_stateless2/export.py \ +./pruned_transducer_stateless2/export-onnx.py \ --exp-dir $repo/exp \ --lang-dir $repo/data/lang_char \ --epoch 99 \ - --avg 1 \ - --onnx 1 + --avg 1 log "Export to torchscript model" @@ -59,19 +57,17 @@ log "Decode with ONNX models" ./pruned_transducer_stateless2/onnx_check.py \ --jit-filename $repo/exp/cpu_jit.pt \ - --onnx-encoder-filename $repo/exp/encoder.onnx \ - --onnx-decoder-filename $repo/exp/decoder.onnx \ - --onnx-joiner-filename $repo/exp/joiner.onnx \ - --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj.onnx \ - --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj.onnx + --onnx-encoder-filename $repo/exp/encoder-epoch-10-avg-2.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-10-avg-2.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-10-avg-2.onnx \ + --onnx-joiner-encoder-proj-filename $repo/exp/joiner_encoder_proj-epoch-10-avg-2.onnx \ + --onnx-joiner-decoder-proj-filename $repo/exp/joiner_decoder_proj-epoch-10-avg-2.onnx ./pruned_transducer_stateless2/onnx_pretrained.py \ --tokens $repo/data/lang_char/tokens.txt \ - --encoder-model-filename $repo/exp/encoder.onnx \ - --decoder-model-filename $repo/exp/decoder.onnx \ - --joiner-model-filename $repo/exp/joiner.onnx \ - --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ - --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ $repo/test_wavs/DEV_T0000000000.wav \ $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav @@ -104,9 +100,9 @@ for sym in 1 2 3; do --lang-dir $repo/data/lang_char \ --decoding-method greedy_search \ --max-sym-per-frame $sym \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav done for method in modified_beam_search beam_search fast_beam_search; do @@ -117,7 +113,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --beam-size 4 \ --checkpoint $repo/exp/epoch-99.pt \ --lang-dir $repo/data/lang_char \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav done diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh index ac16131d0..4073c594a 100755 --- a/.github/scripts/test-ncnn-export.sh +++ b/.github/scripts/test-ncnn-export.sh @@ -45,7 +45,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" cd exp @@ -56,11 +55,10 @@ log "Export via torch.jit.trace()" ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ - \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --num-encoder-layers 12 \ --chunk-length 32 \ --cnn-module-kernel 31 \ @@ -91,7 +89,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-iter-468000-avg-16.pt" cd exp @@ -102,7 +99,7 @@ log "Export via torch.jit.trace()" ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 @@ -140,7 +137,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -148,7 +144,7 @@ ln -s pretrained.pt epoch-99.pt popd ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -199,7 +195,7 @@ ln -s pretrained.pt epoch-9999.pt popd ./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 9999 \ diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh index 39467c44a..fcfc11fa6 100755 --- a/.github/scripts/test-onnx-export.sh +++ b/.github/scripts/test-onnx-export.sh @@ -10,7 +10,123 @@ log() { cd egs/librispeech/ASR +log "==========================================================================" +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) +pushd $repo +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Export via torch.jit.script()" +./zipformer/export.py \ + --exp-dir $repo/exp \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +log "Test export to ONNX format" +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +ls -lh $repo/exp + +log "Run onnx_check.py" + +./zipformer/onnx_check.py \ + --jit-filename $repo/exp/jit_script.pt \ + --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx + +log "Run onnx_pretrained.py" + +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +log "Test export streaming model to ONNX format" +./zipformer/export-onnx-streaming.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal True \ + --chunk-size 16 \ + --left-context-frames 64 + +ls -lh $repo/exp + +log "Run onnx_pretrained-streaming.py" + +./zipformer/onnx_pretrained-streaming.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1-chunk-16-left-64.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1-chunk-16-left-64.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav + +rm -rf $repo + +log "--------------------------------------------------------------------------" log "==========================================================================" repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 @@ -39,7 +155,7 @@ log "Export via torch.jit.trace()" log "Test exporting to ONNX format" ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -88,7 +204,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless3/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ \ @@ -97,7 +213,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ @@ -126,7 +242,6 @@ log "Run onnx_pretrained.py" rm -rf $repo log "--------------------------------------------------------------------------" - log "==========================================================================" repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless5-2022-05-13 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url @@ -143,7 +258,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless5/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -159,7 +274,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless5/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -205,7 +320,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -215,7 +329,7 @@ popd log "Export via torch.jit.script()" ./pruned_transducer_stateless7/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -226,7 +340,7 @@ log "Export via torch.jit.script()" log "Test exporting to ONNX format" ./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -270,7 +384,7 @@ popd log "Test exporting to ONNX format" ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -310,7 +424,7 @@ popd log "Export via torch.jit.trace()" ./lstm_transducer_stateless2/export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -320,7 +434,7 @@ log "Export via torch.jit.trace()" log "Test exporting to ONNX format" ./lstm_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py deleted file mode 100755 index 1b0e8d3b9..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless7/decode.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --lang-dir data/lang_char - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - # You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py new file mode 120000 index 000000000..2713792e6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py deleted file mode 100644 index cc54027d6..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by -./pruned_transducer_stateless7/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - token_table = lexicon.token_table - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp_tokens = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp_tokens = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py new file mode 120000 index 000000000..068f0f57f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index fbcbd7b29..f0bb97560 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -23,12 +23,13 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +64,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -98,16 +98,16 @@ def get_params() -> AttributeDict: def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) logging.info(params) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 30def9c40..df3e4d819 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -24,7 +24,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from conformer import Conformer @@ -70,11 +69,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -83,10 +80,9 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. + (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file + to convert tokens to actual words or characters. It needs + neither a lexicon nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -297,6 +293,7 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") + hyps = [] features = fbank(waves) features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) @@ -313,10 +310,17 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + H = k2.ctc_topo( max_token=max_token_id, modified=params.num_classes > 500, @@ -337,9 +341,9 @@ def main(): best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method in [ "1best", "whole-lattice-rescoring", @@ -408,16 +412,16 @@ def main(): ) best_path = next(iter(best_path_dict.values())) - hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(" ".join([word_sym_table[i] for i in hyp])) else: raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 7892b03c6..26a95dbfa 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -23,6 +23,7 @@ Usage: ./conformer_ctc2/export.py \ --exp-dir ./conformer_ctc2/exp \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,6 +47,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from decode import get_params @@ -56,8 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -123,10 +124,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="The lang dir", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -143,14 +144,14 @@ def get_parser(): def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank device = torch.device("cpu") if torch.cuda.is_available(): diff --git a/egs/librispeech/ASR/conformer_ctc3/export.py b/egs/librispeech/ASR/conformer_ctc3/export.py index c5b95d981..5cb9b4b6d 100755 --- a/egs/librispeech/ASR/conformer_ctc3/export.py +++ b/egs/librispeech/ASR/conformer_ctc3/export.py @@ -25,7 +25,7 @@ Usage: ./conformer_ctc3/export.py \ --exp-dir ./conformer_ctc3/exp \ - --lang-dir data/lang_bpe_500 \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit-trace 1 @@ -36,7 +36,7 @@ It will generates the file: `jit_trace.pt`. ./conformer_ctc3/export.py \ --exp-dir ./conformer_ctc3/exp \ - --lang-dir data/lang_bpe_500 \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -62,6 +62,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -72,8 +73,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -130,10 +130,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -171,9 +171,10 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + num_classes = num_tokens(token_table) + 1 # +1 for the blank params.vocab_size = num_classes if params.streaming_model: diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py index 880945ea0..c37b99cce 100755 --- a/egs/librispeech/ASR/conformer_ctc3/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -24,7 +24,7 @@ Usage (for non-streaming mode): (1) ctc-decoding ./conformer_ctc3/pretrained.py \ --checkpoint conformer_ctc3/exp/pretrained.pt \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method ctc-decoding \ --sample-rate 16000 \ test_wavs/1089-134686-0001.wav @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from decode import get_decoding_params @@ -114,11 +113,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -127,10 +124,9 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. + (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file + to convert tokens to actual words or characters. It needs + neither a lexicon nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -316,6 +312,7 @@ def main(): waves = [w.to(device) for w in waves] logging.info("Decoding started") + hyps = [] features = fbank(waves) feature_lengths = [f.size(0) for f in features] @@ -348,10 +345,17 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) max_token_id = params.num_classes - 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + H = k2.ctc_topo( max_token=max_token_id, modified=False, @@ -372,9 +376,9 @@ def main(): best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method in [ "1best", "nbest-rescoring", @@ -439,16 +443,16 @@ def main(): ) best_path = next(iter(best_path_dict.values())) - hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + hyp_tokens = get_texts(best_path) + for hyp in hyp_tokens: + hyps.append(" ".join([word_sym_table[i] for i in hyp])) else: raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 09a3e96b0..67fcc35a4 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./conv_emformer_transducer_stateless/export.py \ --exp-dir ./conv_emformer_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -62,7 +62,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -72,7 +72,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -118,10 +118,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -166,12 +166,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py index 8fbb02f14..85dbd4661 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -8,7 +8,7 @@ for more details about how to use this file. Usage: ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -37,7 +37,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -48,7 +48,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -94,10 +94,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -217,12 +217,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py index ad0b45bd9..cfd365207 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-onnx.py @@ -18,7 +18,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" cd exp @@ -28,7 +27,7 @@ popd 2. Export the model to ONNX ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -55,14 +54,14 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder +from emformer import Emformer from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model -from emformer import Emformer from icefall.checkpoint import ( average_checkpoints, @@ -70,7 +69,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -127,10 +126,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -484,12 +483,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index b53426c75..8e5b14903 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./conv_emformer_transducer_stateless2/export.py \ --exp-dir ./conv_emformer_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 10 \ --use-averaged-model=True \ @@ -62,7 +62,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -73,7 +73,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -119,10 +119,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -167,12 +167,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py index db92ac696..5d7e2dfcd 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/onnx_pretrained.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./conv_emformer_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py index e338342cc..c007220d5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -26,7 +26,7 @@ Usage: ./lstm_transducer_stateless/export.py \ --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 \ --jit-trace 1 @@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless/export.py \ --exp-dir ./lstm_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 @@ -79,7 +79,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -91,7 +91,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -148,10 +148,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -266,12 +266,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index b3a34a9e3..119fcf1fd 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./lstm_transducer_stateless/pretrained.py \ --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +78,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -214,13 +215,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +277,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +294,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +304,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +334,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py index 08bfcb204..2b8c92208 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-for-ncnn.py @@ -29,7 +29,7 @@ popd ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $repo/exp \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -49,7 +49,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -60,7 +60,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -106,10 +106,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -221,12 +221,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py index f068f6a0f..89ced388c 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -613,7 +613,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py index acaff8540..6b6cb893f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./lstm_transducer_stateless2/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -52,8 +52,8 @@ import logging from pathlib import Path from typing import Dict, Optional, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -68,7 +68,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -125,10 +125,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -437,12 +437,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -607,7 +608,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 0adc68112..5712da25e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -27,7 +27,7 @@ Usage: ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 \ --jit-trace 1 @@ -39,7 +39,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 35 \ --avg 10 @@ -80,7 +80,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -92,7 +92,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -149,10 +149,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -267,12 +267,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index f3f272b9f..5d6d97320 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./lstm_transducer_stateless2/pretrained.py \ --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -69,7 +69,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -82,6 +81,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -98,9 +99,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -217,13 +218,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -278,6 +280,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -289,8 +297,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -299,16 +307,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -329,12 +337,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index a82cad043..21eaa049b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -26,7 +26,7 @@ Usage: ./lstm_transducer_stateless3/export.py \ --exp-dir ./lstm_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 40 \ --avg 20 \ --jit-trace 1 @@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`, ./lstm_transducer_stateless3/export.py \ --exp-dir ./lstm_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 40 \ --avg 20 @@ -79,7 +79,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -91,7 +91,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -148,10 +148,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to tokens.txt.", ) parser.add_argument( @@ -266,12 +266,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index f49e9c518..29a0d4d1a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./lstm_transducer_stateless3/pretrained.py \ --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -79,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -214,13 +216,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +278,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +295,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +305,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +335,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 3612a2bfd..ec2c9d580 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -22,7 +22,7 @@ Usage: ./prunted_stateless_emformer_rnnt/export.py \ --exp-dir ./prunted_stateless_emformer_rnnt/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model @@ -58,7 +58,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -115,10 +115,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -154,13 +154,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py index a3ebe9d8c..282238c13 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export-onnx.py @@ -508,7 +508,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index a19f9ab9a..4b20e3a2b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless/export.py \ --exp-dir ./pruned_transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -87,10 +87,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,13 +135,13 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size, is + # defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 2ed1725b4..02f9f1b03 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless/pretrained.py \ --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -237,13 +236,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -314,6 +314,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -325,8 +331,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -335,16 +341,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -365,12 +371,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 984caf5f2..e02afa892 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless2/export.py \ --exp-dir ./pruned_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -145,12 +145,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 013964720..029f55ba0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless2/pretrained.py \ --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,7 +78,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -97,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -238,13 +237,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -315,6 +315,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -326,8 +332,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -336,16 +342,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -366,12 +372,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index 9645b7801..26dea7e11 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless3/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 9999 \ --avg 1 \ --exp-dir $repo/exp/ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer @@ -59,7 +59,7 @@ from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import setup_logger +from icefall.utils import num_tokens, setup_logger def get_parser(): @@ -105,10 +105,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -393,12 +393,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -518,7 +520,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index f30c9df6a..925b15646 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit 1 @@ -44,7 +44,7 @@ It will also generate 3 other files: `encoder_jit_script.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 \ --jit-trace 1 @@ -56,7 +56,7 @@ It will generates 3 files: `encoder_jit_trace.pt`, ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -97,14 +97,14 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -150,10 +150,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -342,12 +342,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 7c3dfc660..abda4e2d4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./pruned_transducer_stateless3/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.simulate_streaming: assert ( @@ -324,6 +324,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -335,8 +341,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -345,16 +351,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -375,12 +381,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index 8f33f5b05..08d736f52 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless4/export.py \ --exp-dir ./pruned_transducer_stateless4/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py index 938ff2f16..549fb13c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx-streaming.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless5/export-onnx-streaming.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -58,13 +58,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -74,7 +74,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -131,10 +131,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -489,12 +489,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -662,7 +664,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index 20fd8dff8..fff0fcdd5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless5/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ @@ -55,13 +55,13 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from conformer import Conformer -from onnxruntime.quantization import QuantType, quantize_dynamic from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -71,7 +71,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -128,10 +128,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -416,12 +416,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -586,7 +588,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index 54f656859..e5223be26 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -48,7 +48,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -59,7 +59,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -116,10 +116,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -164,12 +164,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for if params.streaming_model: assert params.causal_convolution diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index 74a2210c3..304fa8693 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -28,7 +28,7 @@ Usage: (2) beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -37,7 +37,7 @@ Usage: (3) modified beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage: (4) fast beam search ./pruned_transducer_stateless5/pretrained.py \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -66,7 +66,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -79,6 +78,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -95,9 +96,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -214,13 +215,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -275,6 +277,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -286,8 +294,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -296,16 +304,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -326,12 +334,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index 4d0d8326c..38f48b2ed 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -22,7 +22,7 @@ Usage: ./pruned_transducer_stateless6/export.py \ --exp-dir ./pruned_transducer_stateless6/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,12 +47,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -98,10 +98,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -135,12 +135,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index d2db92820..11c885f4d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) """ This script exports a transducer model from PyTorch to ONNX. @@ -18,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained-epoch-30-avg-9.pt" cd exp @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -50,8 +50,8 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -66,7 +66,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,10 +123,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -411,12 +410,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -581,7 +580,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 3e3160e7e..eb4c4d282 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -26,7 +27,7 @@ Usage: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +46,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -65,7 +66,7 @@ you can do: --avg 1 \ --max-duration 600 \ --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model + --tokens data/lang_bpe_500/tokens.txt \ Check ./pretrained.py for its usage. @@ -86,7 +87,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -98,7 +99,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -155,10 +156,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +198,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -292,7 +292,7 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.jit: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index d05bafcfb..86c922cda 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -20,7 +21,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +30,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +38,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +47,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +56,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +76,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +88,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens def get_parser(): @@ -106,9 +106,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -225,13 +225,13 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py index c1607699f..51e62d6a8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py index 2f1b1a49f..78e0fa778 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 +87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py index 5d460edb5..904c1deae 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ @@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./bruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py index 05df8cfff..9f35cf63e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -197,12 +197,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py index 630a7f735..d3033b888 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py @@ -28,7 +28,7 @@ Usage: ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx 1 @@ -48,7 +48,7 @@ Check `onnx_check.py` for how to use them. (2) Export to ONNX format which can be used in Triton Server ./pruned_transducer_stateless7_ctc_bs/export_onnx.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 \ --onnx-triton 1 @@ -86,9 +86,10 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn +from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -98,8 +99,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool -from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -156,10 +156,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -728,12 +728,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py index ea0fe9164..5d240cf30 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 13 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_ctc_bs/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,6 +87,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -104,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -223,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -284,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -295,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -305,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -335,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py index 412631ba1..914107526 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py @@ -22,14 +22,14 @@ You can use the following command to get the exported models: ./pruned_transducer_stateless7_ctc_bs/export.py \ --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 Usage of this script: (1) ctc-decoding -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ @@ -38,7 +38,7 @@ Usage of this script: /path/to/bar.wav (2) 1best -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -48,7 +48,7 @@ Usage of this script: /path/to/bar.wav (3) nbest-rescoring -./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ +./bruned_transducer_stateless7_ctc/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ @@ -60,7 +60,7 @@ Usage of this script: (4) whole-lattice-rescoring -./pruned_transducer_stateless7_ctc_bs/jit_pretrained_ctc.py \ +./pruned_transducer_stateless7_ctc_bs/pretrained_ctc.py \ --checkpoint ./pruned_transducer_stateless7_ctc_bs/exp/pretrained.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py index e196f8b7d..07de57a86 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py @@ -66,6 +66,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -123,10 +123,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -246,9 +246,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py index 04d97808d..8653126de 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -29,7 +29,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ - --lang-dir $repo/data/lang_char_bpe \ + --tokens $repo/data/lang_char_bpe/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -60,6 +60,7 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -76,8 +77,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -134,10 +134,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="The tokens.txt file", ) parser.add_argument( @@ -493,9 +493,14 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -661,7 +666,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index e71bcaf29..6f84d79b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -27,7 +27,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ @@ -48,8 +48,8 @@ import logging from pathlib import Path from typing import Dict, List, Tuple +import k2 import onnx -import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder @@ -65,7 +65,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -122,10 +122,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -481,12 +481,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -652,7 +654,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py index c191b5bcc..59a7eb589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -139,8 +139,8 @@ import argparse import logging from pathlib import Path +import k2 import onnxruntime -import sentencepiece as spm import torch import torch.nn as nn from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner @@ -154,7 +154,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -211,10 +211,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -675,12 +675,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py index fb77fdd42..bc42e8d05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless7_streaming/export.py \ --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7_streaming/pretrained.py \ --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py index 4a16a97fb..9a6b31268 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/export-for-ncnn.py @@ -28,7 +28,7 @@ popd 2. Export to ncnn ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --exp-dir $repo/exp \ --use-averaged-model 0 \ --epoch 99 \ @@ -64,7 +64,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train2 import add_model_arguments, get_params, get_transducer_model @@ -75,7 +75,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import setup_logger, str2bool +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -121,10 +121,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -244,12 +244,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index d4a228b47..d9697680b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -26,7 +26,7 @@ Usage: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -98,7 +98,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -155,10 +155,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +198,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 486d9d74e..64b38c9d5 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -20,7 +20,7 @@ You can generate the checkpoint with the following command: ./pruned_transducer_stateless8/export.py \ --exp-dir ./pruned_transducer_stateless8/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -29,7 +29,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +37,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +46,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +55,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless8/pretrained.py \ --checkpoint ./pruned_transducer_stateless8/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +75,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -88,7 +87,7 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -106,9 +105,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -225,13 +224,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +286,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +303,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +313,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +343,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py index 6db0272f0..3b9e4a5dc 100755 --- a/egs/librispeech/ASR/transducer/export.py +++ b/egs/librispeech/ASR/transducer/export.py @@ -22,7 +22,7 @@ Usage: ./transducer/export.py \ --exp-dir ./transducer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 34 \ --avg 11 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from conformer import Conformer from decoder import Decoder @@ -55,7 +55,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -90,10 +90,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 511610245..c2413f5de 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -19,7 +19,7 @@ Usage: ./transducer/pretrained.py \ --checkpoint ./transducer/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav \ @@ -36,8 +36,8 @@ import logging import math from typing import List +import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import beam_search, greedy_search @@ -48,7 +48,7 @@ from model import Transducer from torch.nn.utils.rnn import pad_sequence from icefall.env import get_env_info -from icefall.utils import AttributeDict +from icefall.utils import AttributeDict, num_tokens def get_parser(): @@ -66,11 +66,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to tokens.txt.", ) parser.add_argument( @@ -204,12 +202,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -257,6 +257,12 @@ def main(): x=features, x_lens=feature_lengths ) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + num_waves = encoder_out.size(0) hyps = [] for i in range(num_waves): @@ -272,12 +278,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 89359f1a4..c397eb171 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless/export.py \ --exp-dir ./transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,7 +56,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -91,10 +91,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 915a6069d..5898dd0f5 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py index d33d02642..f4b6f5554 100755 --- a/egs/librispeech/ASR/transducer_stateless2/export.py +++ b/egs/librispeech/ASR/transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless2/export.py \ --exp-dir ./transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,12 +46,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -86,10 +86,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -123,12 +123,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index 0738f30c0..b69b347ef 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py index 3735ef452..6d31dfe34 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless_multi_datasets/export.py \ --exp-dir ./transducer_stateless_multi_datasets/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,7 +47,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -57,7 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -92,10 +92,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -192,12 +192,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 8c7726367..4f29d6f1f 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 3eb06f68c..a951aeef3 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp @@ -74,7 +73,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -86,7 +84,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index 724fdd2a6..e0d664009 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp @@ -71,7 +70,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from export import num_tokens from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -83,7 +81,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, str2bool +from icefall.utils import make_pad_mask, num_tokens, str2bool def get_parser(): diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py index 4a48d5bad..2b8d1aaf3 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ b/egs/librispeech/ASR/zipformer/export.py @@ -160,7 +160,6 @@ with the following commands: import argparse import logging -import re from pathlib import Path from typing import List, Tuple @@ -176,27 +175,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import make_pad_mask, str2bool - - -def num_tokens( - token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") -) -> int: - """Return the number of tokens excluding those from - disambiguation symbols. - - Caution: - 0 is not a token ID so it is excluded from the return value. - """ - symbols = token_table.symbols - ans = [] - for s in symbols: - if not disambig_pattern.match(s): - ans.append(token_table[s]) - num_tokens = len(ans) - if 0 in ans: - num_tokens -= 1 - return num_tokens +from icefall.utils import make_pad_mask, num_tokens, str2bool def get_parser(): @@ -487,6 +466,8 @@ def main(): device=device, ) ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: assert params.avg > 0, params.avg start = params.epoch - params.avg diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index 904d8cd76..660a4bfc6 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -410,10 +410,20 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer/onnx_check.py b/egs/librispeech/ASR/zipformer/onnx_check.py index b38b875d0..93bd3a211 100755 --- a/egs/librispeech/ASR/zipformer/onnx_check.py +++ b/egs/librispeech/ASR/zipformer/onnx_check.py @@ -33,7 +33,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py index 2ce4506a8..500b2cd09 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" git lfs pull --include "exp/pretrained.pt" cd exp @@ -29,7 +28,7 @@ popd 2. Export the model to ONNX ./zipformer/export-onnx-streaming.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py index e8a521460..032b07721 100755 --- a/egs/librispeech/ASR/zipformer/onnx_pretrained.py +++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py @@ -31,7 +31,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) pushd $repo -git lfs pull --include "data/lang_bpe_500/tokens.txt" git lfs pull --include "exp/pretrained.pt" cd exp diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index be239e9c3..9dff2e6fc 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -274,7 +274,7 @@ def main(): params.update(vars(args)) token_table = k2.SymbolTable.from_file(params.tokens) - params.vocab_size = num_tokens(token_table) + params.vocab_size = num_tokens(token_table) + 1 # +1 for blank params.blank_id = token_table[""] assert params.blank_id == 0 @@ -429,10 +429,20 @@ def main(): raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" + if params.method == "ctc-decoding": + for filename, hyp in zip(params.sound_files, hyps): + words = "".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + words = words.replace("▁", " ").strip() + s += f"{filename}:\n{words}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer_mmi/export.py b/egs/librispeech/ASR/zipformer_mmi/export.py index 0af7bd367..1aec56420 100755 --- a/egs/librispeech/ASR/zipformer_mmi/export.py +++ b/egs/librispeech/ASR/zipformer_mmi/export.py @@ -26,7 +26,7 @@ Usage: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -190,12 +190,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py index 0e7fd0daf..3ba4da5dd 100755 --- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py @@ -21,7 +21,7 @@ You can generate the checkpoint with the following command: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -30,14 +30,14 @@ Usage of this script: (1) 1best ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method 1best \ /path/to/foo.wav \ /path/to/bar.wav (2) nbest ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest \ /path/to/foo.wav \ @@ -45,7 +45,7 @@ Usage of this script: (3) nbest-rescoring-LG ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-LG \ /path/to/foo.wav \ @@ -53,7 +53,7 @@ Usage of this script: (4) nbest-rescoring-3-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-3-gram \ /path/to/foo.wav \ @@ -61,7 +61,7 @@ Usage of this script: (5) nbest-rescoring-4-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-4-gram \ /path/to/foo.wav \ @@ -83,7 +83,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from decode import get_decoding_params @@ -97,7 +96,7 @@ from icefall.decode import ( one_best_decoding, ) from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import get_texts +from icefall.utils import get_texts, num_tokens def get_parser(): @@ -115,9 +114,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -298,8 +298,6 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) mmi_graph_compiler = MmiTrainingGraphCompiler( params.lang_dir, uniq_filename="lexicon.txt", @@ -313,6 +311,12 @@ def main(): if not hasattr(HP, "lm_scores"): HP.lm_scores = HP.scores.clone() + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + method = params.method assert method in ( "1best", @@ -390,14 +394,11 @@ def main(): # # token_ids is a lit-of-list of IDs token_ids = get_texts(best_path) - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] + hyps = [token_ids_to_words(ids) for ids in token_ids] + s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py index fad66986b..760fad974 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py @@ -498,7 +498,7 @@ def main(): quantize_dynamic( model_input=decoder_filename, model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul"], + op_types_to_quantize=["MatMul", "Gather"], weight_type=QuantType.QInt8, ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py index bc499f3dd..c3d67ad92 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -320,7 +320,7 @@ def main(): s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) + words = "".join(hyp) s += f"{filename}:\n{words}\n\n" logging.info(s) diff --git a/icefall/utils.py b/icefall/utils.py index 0feff9dc8..b01cd2770 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -2060,3 +2060,23 @@ def symlink_or_copy(exp_dir: Path, src: str, dst: str): except OSError: copyfile(src=exp_dir / src, dst=exp_dir / dst) os.close(dir_fd) + + +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens From dfccadc6b6551696e2dfff787f3ec102e346d4cd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 12 Aug 2023 16:59:06 +0800 Subject: [PATCH 3/7] Fix a typo in export_onnx.py for yesno (#1213) --- egs/yesno/ASR/tdnn/export_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/yesno/ASR/tdnn/export_onnx.py b/egs/yesno/ASR/tdnn/export_onnx.py index 9b2a56d59..2436ca81b 100755 --- a/egs/yesno/ASR/tdnn/export_onnx.py +++ b/egs/yesno/ASR/tdnn/export_onnx.py @@ -126,7 +126,7 @@ def main(): logging.info(f"Saved to {onnx_filename}") meta_data = { - "model_type": "tdnn_lstm", + "model_type": "tdnn", "version": "1", "model_author": "k2-fsa", "comment": "non-streaming tdnn for the yesno recipe", From b0e8a40c8932d82d356b8a2ad4948331eae9749e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 12 Aug 2023 21:50:59 -0400 Subject: [PATCH 4/7] Speed up yesno training to finish in ~10s on CPU (#1215) --- egs/yesno/ASR/tdnn/asr_datamodule.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index 3c1682fa1..ada8c1a6c 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -209,7 +209,7 @@ class YesNoAsrDataModule(DataModule): sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, - persistent_workers=False, + persistent_workers=True, ) return train_dl @@ -236,6 +236,7 @@ class YesNoAsrDataModule(DataModule): batch_size=None, sampler=sampler, num_workers=self.args.num_workers, + persistent_workers=True, ) return test_dl From 3b5645f5944393121e52739d5b9d5ef43a7e7a0f Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 13 Aug 2023 12:37:08 +0800 Subject: [PATCH 5/7] doc updated (#1214) --- docs/source/model-export/export-model-state-dict.rst | 4 ++-- docs/source/model-export/export-ncnn-conv-emformer.rst | 3 +-- docs/source/model-export/export-ncnn-lstm.rst | 2 +- docs/source/model-export/export-ncnn-zipformer.rst | 3 +-- docs/source/model-export/export-onnx.rst | 2 +- docs/source/model-export/export-with-torch-jit-script.rst | 2 +- docs/source/model-export/export-with-torch-jit-trace.rst | 2 +- 7 files changed, 8 insertions(+), 10 deletions(-) diff --git a/docs/source/model-export/export-model-state-dict.rst b/docs/source/model-export/export-model-state-dict.rst index c3bbd5708..5596bb7a6 100644 --- a/docs/source/model-export/export-model-state-dict.rst +++ b/docs/source/model-export/export-model-state-dict.rst @@ -41,7 +41,7 @@ as an example. ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -78,7 +78,7 @@ In each recipe, there is also a file ``pretrained.py``, which can use ./pruned_transducer_stateless3/pretrained.py \ --checkpoint ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp/pretrained-iter-1224000-avg-14.pt \ - --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model \ + --tokens ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/tokens.txt \ --method greedy_search \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav \ ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav \ diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst index 12b370143..4f5535d83 100644 --- a/docs/source/model-export/export-ncnn-conv-emformer.rst +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -153,11 +153,10 @@ Next, we use the following code to export our model: ./conv_emformer_transducer_stateless2/export-for-ncnn.py \ --exp-dir $dir/exp \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --tokens $dir/data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 1 \ --use-averaged-model 0 \ - \ --num-encoder-layers 12 \ --chunk-length 32 \ --cnn-module-kernel 31 \ diff --git a/docs/source/model-export/export-ncnn-lstm.rst b/docs/source/model-export/export-ncnn-lstm.rst index 8e6dc7466..310c3d8e4 100644 --- a/docs/source/model-export/export-ncnn-lstm.rst +++ b/docs/source/model-export/export-ncnn-lstm.rst @@ -73,7 +73,7 @@ Next, we use the following code to export our model: ./lstm_transducer_stateless2/export-for-ncnn.py \ --exp-dir $dir/exp \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --tokens $dir/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --use-averaged-model 0 \ diff --git a/docs/source/model-export/export-ncnn-zipformer.rst b/docs/source/model-export/export-ncnn-zipformer.rst index 8440d26b7..a5845b0e4 100644 --- a/docs/source/model-export/export-ncnn-zipformer.rst +++ b/docs/source/model-export/export-ncnn-zipformer.rst @@ -72,12 +72,11 @@ Next, we use the following code to export our model: dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --bpe-model $dir/data/lang_bpe_500/bpe.model \ + --tokens $dir/data/lang_bpe_500/tokens.txt \ --exp-dir $dir/exp \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ - \ --decode-chunk-len 32 \ --num-left-chunks 4 \ --num-encoder-layers "2,4,3,2,4" \ diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst index fb952abb7..d95f2acfe 100644 --- a/docs/source/model-export/export-onnx.rst +++ b/docs/source/model-export/export-onnx.rst @@ -71,7 +71,7 @@ Export the model to ONNX .. code-block:: bash ./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --use-averaged-model 0 \ --epoch 99 \ --avg 1 \ diff --git a/docs/source/model-export/export-with-torch-jit-script.rst b/docs/source/model-export/export-with-torch-jit-script.rst index efd7dc2e1..31c8f0bf5 100644 --- a/docs/source/model-export/export-with-torch-jit-script.rst +++ b/docs/source/model-export/export-with-torch-jit-script.rst @@ -32,7 +32,7 @@ as an example in the following. ./pruned_transducer_stateless3/export.py \ --exp-dir ./pruned_transducer_stateless3/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch $epoch \ --avg $avg \ --jit 1 diff --git a/docs/source/model-export/export-with-torch-jit-trace.rst b/docs/source/model-export/export-with-torch-jit-trace.rst index 506459909..be7876ab5 100644 --- a/docs/source/model-export/export-with-torch-jit-trace.rst +++ b/docs/source/model-export/export-with-torch-jit-trace.rst @@ -33,7 +33,7 @@ as an example in the following. ./lstm_transducer_stateless2/export.py \ --exp-dir ./lstm_transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --iter $iter \ --avg $avg \ --jit-trace 1 From 9a47c08d085f00b63ce2d7c6d0fee16812691ed7 Mon Sep 17 00:00:00 2001 From: Erwan Zerhouni <61225408+ezerhouni@users.noreply.github.com> Date: Mon, 14 Aug 2023 16:10:50 +0200 Subject: [PATCH 6/7] Update padding modified beam search (#1217) --- .../beam_search.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index fd59d4b7f..97e259b40 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1008,7 +1008,7 @@ def modified_beam_search( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), context_state=None if context_graph is None else context_graph.root, timestamp=[], @@ -1217,7 +1217,7 @@ def modified_beam_search_lm_rescore( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1417,7 +1417,7 @@ def modified_beam_search_lm_rescore_LODR( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1617,7 +1617,7 @@ def _deprecated_modified_beam_search( B = HypothesisList() B.add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), timestamp=[], ) @@ -1753,7 +1753,11 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) + B.add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] + ) + ) max_sym_per_utt = 20000 @@ -2265,7 +2269,7 @@ def modified_beam_search_ngram_rescoring( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state_cost=NgramLmStateCost(ngram_lm), ) @@ -2446,7 +2450,7 @@ def modified_beam_search_LODR( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, # state of the NN LM lm_score=init_score.reshape(-1), @@ -2709,7 +2713,7 @@ def modified_beam_search_lm_shallow_fusion( for i in range(N): B[i].add( Hypothesis( - ys=[blank_id] * context_size, + ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), state=init_states, lm_score=init_score.reshape(-1), From fc2df07841b3edbd7bffddfcc2e016515aa75247 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 16 Aug 2023 22:32:41 +0800 Subject: [PATCH 7/7] Add icefall tutorials for dummies. (#1220) --- docs/source/conf.py | 3 + docs/source/for-dummies/data-preparation.rst | 180 ++++++++++ docs/source/for-dummies/decoding.rst | 39 +++ docs/source/for-dummies/environment-setup.rst | 121 +++++++ docs/source/for-dummies/index.rst | 34 ++ docs/source/for-dummies/model-export.rst | 310 ++++++++++++++++++ docs/source/for-dummies/training.rst | 39 +++ docs/source/index.rst | 1 + egs/yesno/ASR/tdnn/onnx_pretrained.py | 1 + 9 files changed, 728 insertions(+) create mode 100644 docs/source/for-dummies/data-preparation.rst create mode 100644 docs/source/for-dummies/decoding.rst create mode 100644 docs/source/for-dummies/environment-setup.rst create mode 100644 docs/source/for-dummies/index.rst create mode 100644 docs/source/for-dummies/model-export.rst create mode 100644 docs/source/for-dummies/training.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index bf231e3c1..5a534e126 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -95,4 +95,7 @@ rst_epilog = """ .. _k2: https://github.com/k2-fsa/k2 .. _lhotse: https://github.com/lhotse-speech/lhotse .. _yesno: https://www.openslr.org/1/ +.. _Next-gen Kaldi: https://github.com/k2-fsa +.. _Kaldi: https://github.com/kaldi-asr/kaldi +.. _lilcom: https://github.com/danpovey/lilcom """ diff --git a/docs/source/for-dummies/data-preparation.rst b/docs/source/for-dummies/data-preparation.rst new file mode 100644 index 000000000..f03d44e79 --- /dev/null +++ b/docs/source/for-dummies/data-preparation.rst @@ -0,0 +1,180 @@ +.. _dummies_tutorial_data_preparation: + +Data Preparation +================ + +After :ref:`dummies_tutorial_environment_setup`, we can start preparing the +data for training and decoding. + +The first step is to prepare the data for training. We have already provided +`prepare.sh `_ +that would prepare everything required for training. + +.. code-block:: + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + ./prepare.sh + +Note that in each recipe from `icefall`_, there exists a file ``prepare.sh``, +which you should run before you run anything else. + +That is all you need for data preparation. + +For the more curious +-------------------- + +If you are wondering how to prepare your own dataset, please refer to the following +URLs for more details: + + - ``_ + + It contains recipes for a variety of dataset. If you want to add your own + dataset, please read recipes in this folder first. + + - ``_ + + The `yesno`_ recipe in `lhotse`_. + +If you already have a `Kaldi`_ dataset directory, which contains files like +``wav.scp``, ``feats.scp``, then you can refer to ``_. + +A quick look to the generated files +----------------------------------- + +``./prepare.sh`` puts generated files into two directories: + + - ``download`` + - ``data`` + +download +^^^^^^^^ + +The ``download`` directory contains downloaded dataset files: + +.. code-block:: bas + + tree -L 1 ./download/ + + ./download/ + |-- waves_yesno + `-- waves_yesno.tar.gz + +.. hint:: + + Please refer to ``_ + for how the data is downloaded and extracted. + +data +^^^^ + +.. code-block:: bash + + tree ./data/ + + ./data/ + |-- fbank + | |-- yesno_cuts_test.jsonl.gz + | |-- yesno_cuts_train.jsonl.gz + | |-- yesno_feats_test.lca + | `-- yesno_feats_train.lca + |-- lang_phone + | |-- HLG.pt + | |-- L.pt + | |-- L_disambig.pt + | |-- Linv.pt + | |-- lexicon.txt + | |-- lexicon_disambig.txt + | |-- tokens.txt + | `-- words.txt + |-- lm + | |-- G.arpa + | `-- G.fst.txt + `-- manifests + |-- yesno_recordings_test.jsonl.gz + |-- yesno_recordings_train.jsonl.gz + |-- yesno_supervisions_test.jsonl.gz + `-- yesno_supervisions_train.jsonl.gz + + 4 directories, 18 files + +**data/manifests**: + + This directory contains manifests. They are used to generate files in + ``data/fbank``. + + To give you an idea of what it contains, we examine the first few lines of + the manifests related to the ``train`` dataset. + + .. code-block:: bash + + cd data/manifests + gunzip -c yesno_recordings_train.jsonl.gz | head -n 3 + + The output is given below: + + .. code-block:: bash + + {"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]} + {"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]} + {"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]} + + Please refer to ``_ + for the meaning of each field per line. + + .. code-block:: bash + + gunzip -c yesno_supervisions_train.jsonl.gz | head -n 3 + + The output is given below: + + .. code-block:: bash + + {"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"} + {"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"} + {"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"} + + Please refer to ``_ + for the meaning of each field per line. + +**data/fbank**: + + This directory contains everything from ``data/manifests``. Furthermore, it also contains features + for training. + + ``data/fbank/yesno_feats_train.lca`` contains the features for the train dataset. + Features are compressed using `lilcom`_. + + ``data/fbank/yesno_cuts_train.jsonl.gz`` stores the `CutSet `_, + which stores `RecordingSet `_, + `SupervisionSet `_, + and `FeatureSet `_. + + To give you an idea about what it looks like, we can run the following command: + + .. code-block:: bash + + cd data/fbank + + gunzip -c yesno_cuts_train.jsonl.gz | head -n 3 + + The output is given below: + + .. code-block:: bash + + {"id": "0_0_0_0_1_1_1_1-0", "start": 0, "duration": 6.35, "channel": 0, "supervisions": [{"id": "0_0_0_0_1_1_1_1", "recording_id": "0_0_0_0_1_1_1_1", "start": 0.0, "duration": 6.35, "channel": 0, "text": "NO NO NO NO YES YES YES YES", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 635, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.35, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "0,13000,3570", "channels": 0}, "recording": {"id": "0_0_0_0_1_1_1_1", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_0_1_1_1_1.wav"}], "sampling_rate": 8000, "num_samples": 50800, "duration": 6.35, "channel_ids": [0]}, "type": "MonoCut"} + {"id": "0_0_0_1_0_1_1_0-1", "start": 0, "duration": 6.11, "channel": 0, "supervisions": [{"id": "0_0_0_1_0_1_1_0", "recording_id": "0_0_0_1_0_1_1_0", "start": 0.0, "duration": 6.11, "channel": 0, "text": "NO NO NO YES NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 611, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.11, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "16570,12964,2929", "channels": 0}, "recording": {"id": "0_0_0_1_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_0_1_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48880, "duration": 6.11, "channel_ids": [0]}, "type": "MonoCut"} + {"id": "0_0_1_0_0_1_1_0-2", "start": 0, "duration": 6.02, "channel": 0, "supervisions": [{"id": "0_0_1_0_0_1_1_0", "recording_id": "0_0_1_0_0_1_1_0", "start": 0.0, "duration": 6.02, "channel": 0, "text": "NO NO YES NO NO YES YES NO", "language": "Hebrew"}], "features": {"type": "kaldi-fbank", "num_frames": 602, "num_features": 23, "frame_shift": 0.01, "sampling_rate": 8000, "start": 0, "duration": 6.02, "storage_type": "lilcom_chunky", "storage_path": "data/fbank/yesno_feats_train.lca", "storage_key": "32463,12936,2696", "channels": 0}, "recording": {"id": "0_0_1_0_0_1_1_0", "sources": [{"type": "file", "channels": [0], "source": "/tmp/icefall/egs/yesno/ASR/download/waves_yesno/0_0_1_0_0_1_1_0.wav"}], "sampling_rate": 8000, "num_samples": 48160, "duration": 6.02, "channel_ids": [0]}, "type": "MonoCut"} + + Note that ``yesno_cuts_train.jsonl.gz`` only stores the information about how to read the features. + The actual features are stored separately in ``data/fbank/yesno_feats_train.lca``. + +**data/lang**: + + This directory contains the lexicon. + +**data/lm**: + + This directory contains language models. diff --git a/docs/source/for-dummies/decoding.rst b/docs/source/for-dummies/decoding.rst new file mode 100644 index 000000000..3e48e8bfd --- /dev/null +++ b/docs/source/for-dummies/decoding.rst @@ -0,0 +1,39 @@ +.. _dummies_tutorial_decoding: + +Decoding +======== + +After :ref:`dummies_tutorial_training`, we can start decoding. + +The command to start the decoding is quite simple: + +.. code-block:: bash + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + # We use CPU for decoding by setting the following environment variable + export CUDA_VISIBLE_DEVICES="" + + ./tdnn/decode.py + +The output logs are given below: + +.. literalinclude:: ./code/decoding-yesno.txt + +For the more curious +-------------------- + +.. code-block:: bash + + ./tdnn/decode.py --help + +will print the usage information about ``./tdnn/decode.py``. For instance, you +can specify: + + - ``--epoch`` to use which checkpoint for decoding + - ``--avg`` to select how many checkpoints to use for model averaging + +You usually try different combinations of ``--epoch`` and ``--avg`` and select +one that leads to the lowest WER (`Word Error Rate `_). diff --git a/docs/source/for-dummies/environment-setup.rst b/docs/source/for-dummies/environment-setup.rst new file mode 100644 index 000000000..0cb8ecc1d --- /dev/null +++ b/docs/source/for-dummies/environment-setup.rst @@ -0,0 +1,121 @@ +.. _dummies_tutorial_environment_setup: + +Environment setup +================= + +We will create an environment for `Next-gen Kaldi`_ that runs on ``CPU`` +in this tutorial. + +.. note:: + + Since the `yesno`_ dataset used in this tutorial is very tiny, training on + ``CPU`` works very well for it. + + If your dataset is very large, e.g., hundreds or thousands of hours of + training data, please follow :ref:`install icefall` to install `icefall`_ + that works with ``GPU``. + + +Create a virtual environment +---------------------------- + +.. code-block:: bash + + virtualenv -p python3 /tmp/icefall_env + +The above command creates a virtual environment in the directory ``/tmp/icefall_env``. +You can select any directory you want. + +The output of the above command is given below: + +.. code-block:: bash + + Already using interpreter /usr/bin/python3 + Using base prefix '/usr' + New python executable in /tmp/icefall_env/bin/python3 + Also creating executable in /tmp/icefall_env/bin/python + Installing setuptools, pkg_resources, pip, wheel...done. + +Now we can activate the environment using: + +.. code-block:: bash + + source /tmp/icefall_env/bin/activate + +Install dependencies +-------------------- + +.. warning:: + + Remeber to activate your virtual environment before you continue! + +After activating the virtual environment, we can use the following command +to install dependencies of `icefall`_: + +.. hint:: + + Remeber that we will run this tutorial on ``CPU``, so we install + dependencies required only by running on ``CPU``. + +.. code-block:: bash + + # Caution: Installation order matters! + + # We use torch 2.0.0 and torchaduio 2.0.0 in this tutorial. + # Other versions should also work. + + pip install torch==2.0.0+cpu torchaudio==2.0.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + + # If you are using macOS or Windows, please use the following command to install torch and torchaudio + # pip install torch==2.0.0 torchaudio==2.0.0 -f https://download.pytorch.org/whl/torch_stable.html + + # Now install k2 + # Please refer to https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cpu-example + + pip install k2==1.24.3.dev20230726+cpu.torch2.0.0 -f https://k2-fsa.github.io/k2/cpu.html + + # Install the latest version of lhotse + + pip install git+https://github.com/lhotse-speech/lhotse + + +Install icefall +--------------- + +We will put the source code of `icefall`_ into the directory ``/tmp`` +You can select any directory you want. + +.. code-block:: bash + + cd /tmp + git clone https://github.com/k2-fsa/icefall + cd icefall + pip install -r ./requirements.txt + +.. code-block:: bash + + # Anytime we want to use icefall, we have to set the following + # environment variable + + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + +.. hint:: + + If you get the following error during this tutorial: + + .. code-block:: bash + + ModuleNotFoundError: No module named 'icefall' + + please set the above environment variable to fix it. + + +Congratulations! You have installed `icefall`_ successfully. + +For the more curious +-------------------- + +`icefall`_ contains a collection of Python scripts and you don't need to +use ``python3 setup.py install`` or ``pip install icefall`` to install it. +All you need to do is to download the code and set the environment variable +``PYTHONPATH``. diff --git a/docs/source/for-dummies/index.rst b/docs/source/for-dummies/index.rst new file mode 100644 index 000000000..7c0a3d8ee --- /dev/null +++ b/docs/source/for-dummies/index.rst @@ -0,0 +1,34 @@ +Icefall for dummies tutorial +============================ + +This tutorial walks you step by step about how to create a simple +ASR (`Automatic Speech Recognition `_) +system with `Next-gen Kaldi`_. + +We use the `yesno`_ dataset for demonstration. We select it out of two reasons: + + - It is quite tiny, containing only about 12 minutes of data + - The training can be finished within 20 seconds on ``CPU``. + +That also means you don't need a ``GPU`` to run this tutorial. + +Let's get started! + +Please follow items below **sequentially**. + +.. note:: + + The :ref:`dummies_tutorial_data_preparation` runs only on Linux and on macOS. + All other parts run on Linux, macOS, and Windows. + + Help from the community is appreciated to port the :ref:`dummies_tutorial_data_preparation` + to Windows. + +.. toctree:: + :maxdepth: 2 + + ./environment-setup.rst + ./data-preparation.rst + ./training.rst + ./decoding.rst + ./model-export.rst diff --git a/docs/source/for-dummies/model-export.rst b/docs/source/for-dummies/model-export.rst new file mode 100644 index 000000000..079ebc712 --- /dev/null +++ b/docs/source/for-dummies/model-export.rst @@ -0,0 +1,310 @@ +Model Export +============ + +There are three ways to export a pre-trained model. + + - Export the model parameters via `model.state_dict() `_ + - Export via `torchscript `_: either `torch.jit.script() `_ or `torch.jit.trace() `_ + - Export to `ONNX`_ via `torch.onnx.export() `_ + +Each method is explained below in detail. + +Export the model parameters via model.state_dict() +--------------------------------------------------- + +The command for this kind of export is + +.. code-block:: bash + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + # assume that "--epoch 14 --avg 2" produces the lowest WER. + + ./tdnn/export.py --epoch 14 --avg 2 + +The output logs are given below: + +.. code-block:: bash + + 2023-08-16 20:42:03,912 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': False} + 2023-08-16 20:42:03,913 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-16 20:42:03,950 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + 2023-08-16 20:42:03,971 INFO [export.py:106] Not using torch.jit.script + 2023-08-16 20:42:03,974 INFO [export.py:111] Saved to tdnn/exp/pretrained.pt + +We can see from the logs that the exported model is saved to the file ``tdnn/exp/pretrained.pt``. + +To give you an idea of what ``tdnn/exp/pretrained.pt`` contains, we can use the following command: + +.. code-block:: python3 + + >>> import torch + >>> m = torch.load("tdnn/exp/pretrained.pt") + >>> list(m.keys()) + ['model'] + >>> list(m["model"].keys()) + ['tdnn.0.weight', 'tdnn.0.bias', 'tdnn.2.running_mean', 'tdnn.2.running_var', 'tdnn.2.num_batches_tracked', 'tdnn.3.weight', 'tdnn.3.bias', 'tdnn.5.running_mean', 'tdnn.5.running_var', 'tdnn.5.num_batches_tracked', 'tdnn.6.weight', 'tdnn.6.bias', 'tdnn.8.running_mean', 'tdnn.8.running_var', 'tdnn.8.num_batches_tracked', 'output_linear.weight', 'output_linear.bias'] + +We can use ``tdnn/exp/pretrained.pt`` in the following way with ``./tdnn/decode.py``: + +.. code-block:: bash + + cd tdnn/exp + ln -s pretrained.pt epoch-99.pt + cd ../.. + + ./tdnn/decode.py --epoch 99 --avg 1 + +The output logs of the above command are given below: + +.. code-block:: bash + + 2023-08-16 20:45:48,089 INFO [decode.py:262] Decoding started + 2023-08-16 20:45:48,090 INFO [decode.py:263] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'feature_dim': 23, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'epoch': 99, 'avg': 1, 'export': False, 'feature_dir': PosixPath('data/fbank'), 'max_duration': 30.0, 'bucketing_sampler': False, 'num_buckets': 10, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': False, 'return_cuts': True, 'num_workers': 2, 'env_info': {'k2-version': '1.24.3', 'k2-build-type': 'Release', 'k2-with-cuda': False, 'k2-git-sha1': 'ad79f1c699c684de9785ed6ca5edb805a41f78c3', 'k2-git-date': 'Wed Jul 26 09:30:42 2023', 'lhotse-version': '1.16.0.dev+git.aa073f6.clean', 'torch-version': '2.0.0', 'torch-cuda-available': False, 'torch-cuda-version': None, 'python-version': '3.1', 'icefall-git-branch': 'master', 'icefall-git-sha1': '9a47c08-clean', 'icefall-git-date': 'Mon Aug 14 22:10:50 2023', 'icefall-path': '/private/tmp/icefall', 'k2-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/k2/__init__.py', 'lhotse-path': '/private/tmp/icefall_env/lib/python3.11/site-packages/lhotse/__init__.py', 'hostname': 'fangjuns-MacBook-Pro.local', 'IP address': '127.0.0.1'}} + 2023-08-16 20:45:48,092 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-16 20:45:48,103 INFO [decode.py:272] device: cpu + 2023-08-16 20:45:48,109 INFO [checkpoint.py:112] Loading checkpoint from tdnn/exp/epoch-99.pt + 2023-08-16 20:45:48,115 INFO [asr_datamodule.py:218] About to get test cuts + 2023-08-16 20:45:48,115 INFO [asr_datamodule.py:253] About to get test cuts + 2023-08-16 20:45:50,386 INFO [decode.py:203] batch 0/?, cuts processed until now is 4 + 2023-08-16 20:45:50,556 INFO [decode.py:240] The transcripts are stored in tdnn/exp/recogs-test_set.txt + 2023-08-16 20:45:50,557 INFO [utils.py:564] [test_set] %WER 0.42% [1 / 240, 0 ins, 1 del, 0 sub ] + 2023-08-16 20:45:50,558 INFO [decode.py:248] Wrote detailed error stats to tdnn/exp/errs-test_set.txt + 2023-08-16 20:45:50,559 INFO [decode.py:315] Done! + +We can see that it produces an identical WER as before. + +We can also use it to decode files with the following command: + +.. code-block:: bash + + # ./tdnn/pretrained.py requires kaldifeat + # + # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html + # for how to install kaldifeat + + pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + + ./tdnn/pretrained.py \ + --checkpoint ./tdnn/exp/pretrained.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +The output is given below: + +.. code-block:: bash + + 2023-08-16 20:53:19,208 INFO [pretrained.py:136] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'checkpoint': './tdnn/exp/pretrained.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']} + 2023-08-16 20:53:19,208 INFO [pretrained.py:142] device: cpu + 2023-08-16 20:53:19,208 INFO [pretrained.py:144] Creating model + 2023-08-16 20:53:19,212 INFO [pretrained.py:156] Loading HLG from ./data/lang_phone/HLG.pt + 2023-08-16 20:53:19,213 INFO [pretrained.py:160] Constructing Fbank computer + 2023-08-16 20:53:19,213 INFO [pretrained.py:170] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav'] + 2023-08-16 20:53:19,224 INFO [pretrained.py:176] Decoding started + 2023-08-16 20:53:19,304 INFO [pretrained.py:212] + download/waves_yesno/0_0_0_1_0_0_0_1.wav: + NO NO NO YES NO NO NO YES + + download/waves_yesno/0_0_1_0_0_0_1_0.wav: + NO NO YES NO NO NO YES NO + + + 2023-08-16 20:53:19,304 INFO [pretrained.py:214] Decoding Done + + +Export via torch.jit.script() +----------------------------- + +The command for this kind of export is + +.. code-block:: bash + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + # assume that "--epoch 14 --avg 2" produces the lowest WER. + + ./tdnn/export.py --epoch 14 --avg 2 --jit true + +The output logs are given below: + +.. code-block:: bash + + 2023-08-16 20:47:44,666 INFO [export.py:76] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2, 'jit': True} + 2023-08-16 20:47:44,667 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-16 20:47:44,670 INFO [export.py:93] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + 2023-08-16 20:47:44,677 INFO [export.py:100] Using torch.jit.script + 2023-08-16 20:47:44,843 INFO [export.py:104] Saved to tdnn/exp/cpu_jit.pt + +From the output logs we can see that the generated file is saved to ``tdnn/exp/cpu_jit.pt``. + +Don't be confused by the name ``cpu_jit.pt``. The ``cpu`` part means the model is moved to +CPU before exporting. That means, when you load it with: + +.. code-block:: bash + + torch.jit.load() + +you don't need to specify the argument `map_location `_ +and it resides on CPU by default. + +To use ``tdnn/exp/cpu_jit.pt`` with `icefall`_ to decode files, we can use: + +.. code-block:: bash + + # ./tdnn/jit_pretrained.py requires kaldifeat + # + # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html + # for how to install kaldifeat + + pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + + + ./tdnn/jit_pretrained.py \ + --nn-model ./tdnn/exp/cpu_jit.pt \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +The output is given below: + +.. code-block:: bash + + 2023-08-16 20:56:00,603 INFO [jit_pretrained.py:121] {'feature_dim': 23, 'num_classes': 4, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/cpu_jit.pt', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']} + 2023-08-16 20:56:00,603 INFO [jit_pretrained.py:127] device: cpu + 2023-08-16 20:56:00,603 INFO [jit_pretrained.py:129] Loading torchscript model + 2023-08-16 20:56:00,640 INFO [jit_pretrained.py:134] Loading HLG from ./data/lang_phone/HLG.pt + 2023-08-16 20:56:00,641 INFO [jit_pretrained.py:138] Constructing Fbank computer + 2023-08-16 20:56:00,641 INFO [jit_pretrained.py:148] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav'] + 2023-08-16 20:56:00,642 INFO [jit_pretrained.py:154] Decoding started + 2023-08-16 20:56:00,727 INFO [jit_pretrained.py:190] + download/waves_yesno/0_0_0_1_0_0_0_1.wav: + NO NO NO YES NO NO NO YES + + download/waves_yesno/0_0_1_0_0_0_1_0.wav: + NO NO YES NO NO NO YES NO + + + 2023-08-16 20:56:00,727 INFO [jit_pretrained.py:192] Decoding Done + +.. hint:: + + We provide only code for ``torch.jit.script()``. You can try ``torch.jit.trace()`` + if you want. + +Export via torch.onnx.export() +------------------------------ + +The command for this kind of export is + +.. code-block:: bash + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + # tdnn/export_onnx.py requires onnx and onnxruntime + pip install onnx onnxruntime + + # assume that "--epoch 14 --avg 2" produces the lowest WER. + + ./tdnn/export_onnx.py \ + --epoch 14 \ + --avg 2 + +The output logs are given below: + +.. code-block:: bash + + 2023-08-16 20:59:20,888 INFO [export_onnx.py:83] {'exp_dir': PosixPath('tdnn/exp'), 'lang_dir': PosixPath('data/lang_phone'), 'lr': 0.01, 'feature_dim': 23, 'weight_decay': 1e-06, 'start_epoch': 0, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 'reset_interval': 20, 'valid_interval': 10, 'beam_size': 10, 'reduction': 'sum', 'use_double_scores': True, 'epoch': 14, 'avg': 2} + 2023-08-16 20:59:20,888 INFO [lexicon.py:168] Loading pre-compiled data/lang_phone/Linv.pt + 2023-08-16 20:59:20,892 INFO [export_onnx.py:100] averaging ['tdnn/exp/epoch-13.pt', 'tdnn/exp/epoch-14.pt'] + ================ Diagnostic Run torch.onnx.export version 2.0.0 ================ + verbose: False, log level: Level.ERROR + ======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ======================== + + 2023-08-16 20:59:21,047 INFO [export_onnx.py:127] Saved to tdnn/exp/model-epoch-14-avg-2.onnx + 2023-08-16 20:59:21,047 INFO [export_onnx.py:136] meta_data: {'model_type': 'tdnn', 'version': '1', 'model_author': 'k2-fsa', 'comment': 'non-streaming tdnn for the yesno recipe', 'vocab_size': 4} + 2023-08-16 20:59:21,049 INFO [export_onnx.py:140] Generate int8 quantization models + 2023-08-16 20:59:21,075 INFO [onnx_quantizer.py:538] Quantization parameters for tensor:"/Transpose_1_output_0" not specified + 2023-08-16 20:59:21,081 INFO [export_onnx.py:151] Saved to tdnn/exp/model-epoch-14-avg-2.int8.onnx + +We can see from the logs that it generates two files: + + - ``tdnn/exp/model-epoch-14-avg-2.onnx`` (ONNX model with ``float32`` weights) + - ``tdnn/exp/model-epoch-14-avg-2.int8.onnx`` (ONNX model with ``int8`` weights) + +To use the generated ONNX model files for decoding with `onnxruntime`_, we can use + +.. code-block:: bash + + # ./tdnn/onnx_pretrained.py requires kaldifeat + # + # Please refer to https://csukuangfj.github.io/kaldifeat/installation/from_wheels.html + # for how to install kaldifeat + + pip install kaldifeat==1.25.0.dev20230726+cpu.torch2.0.0 -f https://csukuangfj.github.io/kaldifeat/cpu.html + + ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +The output is given below: + +.. code-block:: bash + + 2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:166] {'feature_dim': 23, 'sample_rate': 8000, 'search_beam': 20, 'output_beam': 8, 'min_active_states': 30, 'max_active_states': 10000, 'use_double_scores': True, 'nn_model': './tdnn/exp/model-epoch-14-avg-2.onnx', 'words_file': './data/lang_phone/words.txt', 'HLG': './data/lang_phone/HLG.pt', 'sound_files': ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav']} + 2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:171] device: cpu + 2023-08-16 21:03:24,260 INFO [onnx_pretrained.py:173] Loading onnx model ./tdnn/exp/model-epoch-14-avg-2.onnx + 2023-08-16 21:03:24,267 INFO [onnx_pretrained.py:176] Loading HLG from ./data/lang_phone/HLG.pt + 2023-08-16 21:03:24,270 INFO [onnx_pretrained.py:180] Constructing Fbank computer + 2023-08-16 21:03:24,273 INFO [onnx_pretrained.py:190] Reading sound files: ['download/waves_yesno/0_0_0_1_0_0_0_1.wav', 'download/waves_yesno/0_0_1_0_0_0_1_0.wav'] + 2023-08-16 21:03:24,279 INFO [onnx_pretrained.py:196] Decoding started + 2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:232] + download/waves_yesno/0_0_0_1_0_0_0_1.wav: + NO NO NO YES NO NO NO YES + + download/waves_yesno/0_0_1_0_0_0_1_0.wav: + NO NO YES NO NO NO YES NO + + + 2023-08-16 21:03:24,318 INFO [onnx_pretrained.py:234] Decoding Done + +.. note:: + + To use the ``int8`` ONNX model for decoding, please use: + + .. code-block:: bash + + ./tdnn/onnx_pretrained.py \ + --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ + --HLG ./data/lang_phone/HLG.pt \ + --words-file ./data/lang_phone/words.txt \ + download/waves_yesno/0_0_0_1_0_0_0_1.wav \ + download/waves_yesno/0_0_1_0_0_0_1_0.wav + +For the more curious +-------------------- + +If you are wondering how to deploy the model without ``torch``, please +continue reading. We will show how to use `sherpa-onnx`_ to run the +exported ONNX models, which depends only on `onnxruntime`_ and does not +depend on ``torch``. + +In this tutorial, we will only demonstrate the usage of `sherpa-onnx`_ with the +pre-trained model of the `yesno`_ recipe. There are also other two frameworks +available: + + - `sherpa`_. It works with torchscript models. + - `sherpa-ncnn`_. It works with models exported using :ref:`icefall_export_to_ncnn` with `ncnn`_ + +Please see ``_ for further details. diff --git a/docs/source/for-dummies/training.rst b/docs/source/for-dummies/training.rst new file mode 100644 index 000000000..816ef2d3b --- /dev/null +++ b/docs/source/for-dummies/training.rst @@ -0,0 +1,39 @@ +.. _dummies_tutorial_training: + +Training +======== + +After :ref:`dummies_tutorial_data_preparation`, we can start training. + +The command to start the training is quite simple: + +.. code-block:: bash + + cd /tmp/icefall + export PYTHONPATH=/tmp/icefall:$PYTHONPATH + cd egs/yesno/ASR + + # We use CPU for training by setting the following environment variable + export CUDA_VISIBLE_DEVICES="" + + ./tdnn/train.py + +That's it! + +You can find the training logs below: + +.. literalinclude:: ./code/train-yesno.txt + +For the more curious +-------------------- + +.. code-block:: bash + + ./tdnn/train.py --help + +will print the usage information about ``./tdnn/train.py``. For instance, you +can specify the number of epochs to train and the location to save the training +results. + +The training text logs are saved in ``tdnn/exp/log`` while the tensorboard +logs are in ``tdnn/exp/tensorboard``. diff --git a/docs/source/index.rst b/docs/source/index.rst index 0fa8fdd1c..fb539d3f2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -20,6 +20,7 @@ speech recognition recipes using `k2 `_. :maxdepth: 2 :caption: Contents: + for-dummies/index.rst installation/index docker/index faqs diff --git a/egs/yesno/ASR/tdnn/onnx_pretrained.py b/egs/yesno/ASR/tdnn/onnx_pretrained.py index 626473b6e..b23a2a381 100755 --- a/egs/yesno/ASR/tdnn/onnx_pretrained.py +++ b/egs/yesno/ASR/tdnn/onnx_pretrained.py @@ -6,6 +6,7 @@ This file shows how to use an ONNX model for decoding with onnxruntime. Usage: (1) Use a not quantized ONNX model, i.e., a float32 model + ./tdnn/onnx_pretrained.py \ --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ --HLG ./data/lang_phone/HLG.pt \