From 8d269156a01c420e86b0ca6115592f41119bf48d Mon Sep 17 00:00:00 2001 From: Yifan Yang Date: Fri, 2 Jun 2023 14:52:12 +0800 Subject: [PATCH] Add export-onnx.py --- icefall/rnn_lm/export-onnx.py | 208 ++++++++++++++++++++++++---------- icefall/rnn_lm/export-onnx.sh | 1 + icefall/rnn_lm/export.py | 4 +- 3 files changed, 155 insertions(+), 58 deletions(-) mode change 100644 => 100755 icefall/rnn_lm/export.py diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py index 1070d443a..2506d5a33 100755 --- a/icefall/rnn_lm/export-onnx.py +++ b/icefall/rnn_lm/export-onnx.py @@ -1,6 +1,42 @@ #!/usr/bin/env python3 # -# Copyright 2023 Xiaomi Corporation +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +""" +This script exports a transducer model from PyTorch to ONNX. + +Export the model to ONNX + +./rnn_lm/export-onnx.py \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir ./rnn_lm/exp + +It will generate the following 4 files inside ./rnn_lm/exp: + + - no-state-epoch-99-avg-1.int8.onnx + - no-state-epoch-99-avg-1.int8.onnx + - with-state-epoch-99-avg-1.int8.onnx + - with-state-epoch-99-avg-1.int8.onnx +""" import argparse import logging @@ -13,7 +49,12 @@ from model import RnnLmModel from onnxruntime.quantization import QuantType, quantize_dynamic from train import get_params -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) from icefall.utils import AttributeDict, str2bool @@ -37,10 +78,6 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]): # A wrapper for RnnLm model to simpily the C++ calling code # when exporting the model to ONNX. -# -# TODO(fangjun): The current wrapper works only for non-streaming ASR -# since we don't expose the LM state and it is used to score -# a complete sentence at once. class RnnLmModelWrapper(torch.nn.Module): def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int): super().__init__() @@ -91,18 +128,10 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=29, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + default=20, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", ) parser.add_argument( @@ -115,6 +144,35 @@ def get_parser(): """, ) + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + parser.add_argument( "--vocab-size", type=int, @@ -152,15 +210,6 @@ def get_parser(): """, ) - parser.add_argument( - "--exp-dir", - type=str, - default="rnn_lm/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - return parser @@ -308,37 +357,82 @@ def main(): model.to(device) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to("cpu") model.eval() diff --git a/icefall/rnn_lm/export-onnx.sh b/icefall/rnn_lm/export-onnx.sh index 6e3262b5e..23c8b4aac 100755 --- a/icefall/rnn_lm/export-onnx.sh +++ b/icefall/rnn_lm/export-onnx.sh @@ -18,6 +18,7 @@ python3 ./export-onnx.py \ --exp-dir ./icefall-librispeech-rnn-lm/exp \ --epoch 99 \ --avg 1 \ + --use-averaged-model 0 \ --vocab-size 500 \ --embedding-dim 2048 \ --hidden-dim 2048 \ diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py old mode 100644 new mode 100755 index a0b7598fb..8657ea9f2 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -44,6 +45,7 @@ for how to use the exported models outside of icefall. ./rnn_lm/export.py \ --exp-dir ./rnn_lm/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ --epoch 20 \ --avg 10