From 90dc5772ec254451014fff2d3ca57e3efe6e8355 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 25 Jul 2022 20:34:13 +0800 Subject: [PATCH] Support decoding from a torchscript model. --- .../pruned_transducer_stateless2/quantize.py | 2 +- .../pruned_transducer_stateless3/decode.py | 111 ++++++------- .../jit_decode.py | 151 ++++++++++++++++++ 3 files changed, 209 insertions(+), 55 deletions(-) create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless3/jit_decode.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/quantize.py b/egs/librispeech/ASR/pruned_transducer_stateless2/quantize.py index 9380f74ad..46f16ec3c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/quantize.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/quantize.py @@ -112,7 +112,7 @@ def dynamic_quantize( """Apply post-training dynamic quantization to a given model. It is also known as post-training weight-only quantization. - Weight are quantized to tensors of dtype torch.qint8. + Weights are quantized to tensors of dtype torch.qint8. Only nn.Linear layers are quantized at present. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index c3a03f2e1..f78cdc371 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -139,60 +139,7 @@ from icefall.utils import ( LOG_EPS = math.log(1e-10) -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless3/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - +def add_decoding_arguments(parser): parser.add_argument( "--decoding-method", type=str, @@ -401,6 +348,62 @@ def get_parser(): """, ) + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + add_decoding_arguments(parser) add_model_arguments(parser) return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_decode.py new file mode 100755 index 000000000..7c306d4f1 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_decode.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes a torchscript model, either quantized or not, and uses +it for decoding. +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from asr_datamodule import AsrDataModule +from decode import add_decoding_arguments, decode_dataset, save_results +from librispeech import LibriSpeech +from train import add_model_arguments, get_params + +from icefall.utils import setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + help="It specifies the path to load the torchscript model", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="Directory to save the decoding results", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + add_decoding_arguments(parser) + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + + # We add only greedy_search for simplicity + assert args.decoding_method == "greedy_search" + + params = get_params() + params.update(vars(args)) + + params.nn_model_filename = Path(args.nn_model_filename) + assert params.nn_model_filename.is_file(), params.nn_model_filename + + params.res_dir = Path(params.exp_dir) / Path(params.nn_model_filename).stem + params.res_dir = params.res_dir / params.decoding_method + + setup_logger(f"{params.res_dir}/log-decode") + + logging.info("Decoding started") + + model = torch.jit.load(params.nn_model_filename) + + device = torch.device("cpu") + if torch.cuda.is_available() and hasattr( + model.simple_lm_proj, "_packed_params" + ): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + model.to(device) + model.device = device + model.unk_id = params.unk_id + + logging.info(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + asr_datamodule = AsrDataModule(args) + librispeech = LibriSpeech(manifest_dir=args.manifest_dir) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = asr_datamodule.test_dataloaders(test_clean_cuts) + test_other_dl = asr_datamodule.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=None, + decoding_graph=None, + G=None, + rnn_lm_model=None, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main()