#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation

import argparse
import logging
from pathlib import Path
from typing import Dict

import onnx
import torch
from model import RnnLmModel
from onnxruntime.quantization import QuantType, quantize_dynamic
from train import get_params

from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, str2bool


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


# A wrapper for the RnnLm model to simplify the C++ calling code
# when exporting the model to ONNX.
#
# TODO(fangjun): The current wrapper works only for non-streaming ASR
# since we don't expose the LM state and it is used to score
# a complete sentence at once.
class RnnLmModelWrapper(torch.nn.Module):
    def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int):
        super().__init__()
        self.model = model
        self.sos_id = sos_id
        self.eos_id = eos_id

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (N, L) with dtype torch.int64.
            It does not contain SOS or EOS. We will add SOS and EOS inside
            this function.
          x_lens:
            A 1-D tensor of shape (N,) with dtype torch.int64. It contains
            the number of valid tokens in ``x`` before padding.
        Returns:
          Return a 1-D tensor of shape (N,) containing the negative
          log-likelihood. Its dtype is torch.float32.
        """
        N = x.size(0)

        # Prepend SOS to every row of x.
        sos_tensor = torch.full((1,), fill_value=self.sos_id, dtype=x.dtype).expand(
            N, 1
        )
        sos_x = torch.cat([sos_tensor, x], dim=1)

        # Append one padding column, then write EOS at position x_lens
        # of each row.
        pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1)
        x_eos = torch.cat([x, pad_col], dim=1)

        row_index = torch.arange(0, N, dtype=x.dtype)
        x_eos[row_index, x_lens] = self.eos_id

        # Use x_lens + 1 here since we prepended x with SOS.
        return (
            self.model(x=sos_x, y=x_eos, lengths=x_lens + 1)
            .to(torch.float32)
            .sum(dim=1)
        )
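
# The helper below is a usage sketch only; it is never called by this script.
# It shows how the wrapper scores a batch of padded sentences. The token IDs,
# lengths, and SOS/EOS values are illustrative assumptions, not values taken
# from any real tokenizer.
def _demo_wrapper_usage(model: RnnLmModel) -> torch.Tensor:
    wrapper = RnnLmModelWrapper(model, sos_id=1, eos_id=1)
    # Two sentences with 3 and 2 valid tokens, right-padded with zeros.
    x = torch.tensor([[3, 5, 8, 0], [4, 9, 0, 0]], dtype=torch.int64)
    x_lens = torch.tensor([3, 2], dtype=torch.int64)
    # Shape (2,): one negative log-likelihood per sentence.
    return wrapper(x, x_lens)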
""", ) parser.add_argument( "--vocab-size", type=int, default=500, help="Vocabulary size of the model", ) parser.add_argument( "--embedding-dim", type=int, default=2048, help="Embedding dim of the model", ) parser.add_argument( "--hidden-dim", type=int, default=2048, help="Hidden dim of the model", ) parser.add_argument( "--num-layers", type=int, default=3, help="Number of RNN layers the model", ) parser.add_argument( "--tie-weights", type=str2bool, default=True, help="""True to share the weights between the input embedding layer and the last output linear layer """, ) parser.add_argument( "--exp-dir", type=str, default="rnn_lm/exp", help="""It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved """, ) return parser def export_without_state( model: RnnLmModel, filename: str, params: AttributeDict, opset_version: int, ): model_wrapper = RnnLmModelWrapper( model, sos_id=params.sos_id, eos_id=params.eos_id, ) N = 1 L = 20 x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) x_lens = torch.full((N,), fill_value=L, dtype=torch.int64) # Note(fangjun): The following warnings can be ignored. # We can use ./check-onnx.py to validate the exported model with batch_size > 1 """ torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX with a batch_size other than 1, " + """ torch.onnx.export( model_wrapper, (x, x_lens), filename, verbose=False, opset_version=opset_version, input_names=["x", "x_lens"], output_names=["nll"], dynamic_axes={ "x": {0: "N", 1: "L"}, "x_lens": {0: "N"}, "nll": {0: "N"}, }, ) meta_data = { "model_type": "rnnlm", "version": "1", "model_author": "k2-fsa", "comment": "rnnlm without state", "sos_id": str(params.sos_id), "eos_id": str(params.eos_id), "vocab_size": str(params.vocab_size), "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", } logging.info(f"meta_data: {meta_data}") add_meta_data(filename=filename, meta_data=meta_data) def export_with_state( model: RnnLmModel, filename: str, params: AttributeDict, opset_version: int, ): N = 1 L = 20 num_layers = model.rnn.num_layers hidden_size = model.rnn.hidden_size embedding_dim = model.embedding_dim x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) h0 = torch.zeros(num_layers, N, hidden_size) c0 = torch.zeros(num_layers, N, hidden_size) # Note(fangjun): The following warnings can be ignored. # We can use ./check-onnx.py to validate the exported model with batch_size > 1 """ torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. 
warnings.warn("Exporting a model to ONNX with a batch_size other than 1, " + """ torch.onnx.export( model, (x, h0, c0), filename, verbose=False, opset_version=opset_version, input_names=["x", "h0", "c0"], output_names=["log_softmax", "next_h0", "next_c0"], dynamic_axes={ "x": {0: "N", 1: "L"}, "h0": {1: "N"}, "c0": {1: "N"}, "log_softmax": {0: "N"}, "next_h0": {1: "N"}, "next_c0": {1: "N"}, }, ) meta_data = { "model_type": "rnnlm", "version": "1", "model_author": "k2-fsa", "comment": "rnnlm state", "sos_id": str(params.sos_id), "eos_id": str(params.eos_id), "vocab_size": str(params.vocab_size), "num_layers": str(num_layers), "hidden_size": str(hidden_size), "embedding_dim": str(embedding_dim), "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", } logging.info(f"meta_data: {meta_data}") add_meta_data(filename=filename, meta_data=meta_data) @torch.no_grad() def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) params = get_params() params.update(vars(args)) logging.info(params) device = torch.device("cpu") logging.info(f"device: {device}") model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.embedding_dim, hidden_dim=params.hidden_dim, num_layers=params.num_layers, tie_weights=params.tie_weights, ) model.to(device) if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ : params.avg ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( f"Not enough checkpoints ({len(filenames)}) found for" f" --iter {params.iter}, --avg {params.avg}" ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict( average_checkpoints(filenames, device=device), strict=False ) elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: start = params.epoch - params.avg + 1 filenames = [] for i in range(start, params.epoch + 1): if i >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict( average_checkpoints(filenames, device=device), strict=False ) model.to("cpu") model.eval() if params.iter > 0: suffix = f"iter-{params.iter}" else: suffix = f"epoch-{params.epoch}" suffix += f"-avg-{params.avg}" opset_version = 13 logging.info("Exporting model without state") filename = params.exp_dir / f"no-state-{suffix}.onnx" export_without_state( model=model, filename=filename, params=params, opset_version=opset_version, ) filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx" quantize_dynamic( model_input=filename, model_output=filename_int8, weight_type=QuantType.QInt8, ) # now for streaming export saved_forward = model.__class__.forward model.__class__.forward = model.__class__.score_token_onnx streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx" export_with_state( model=model, filename=streaming_filename, params=params, opset_version=opset_version, ) model.__class__.forward = saved_forward streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx" quantize_dynamic( model_input=streaming_filename, model_output=streaming_filename_int8, weight_type=QuantType.QInt8, ) if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) main()