From 44d016e4a74e708d85dd91116903295dbbc6c922 Mon Sep 17 00:00:00 2001 From: PF Luo Date: Wed, 10 May 2023 22:41:07 +0800 Subject: [PATCH] export score_token interface for onnx-runtime (#1050) --- icefall/rnn_lm/check-onnx-streaming.py | 9 +++------ icefall/rnn_lm/export-onnx.py | 12 +++++------- icefall/rnn_lm/export.py | 4 ++-- icefall/rnn_lm/model.py | 27 ++++++++++++++++++++++++++ 4 files changed, 37 insertions(+), 15 deletions(-) diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py index 8850c1c71..d51a4b76b 100755 --- a/icefall/rnn_lm/check-onnx-streaming.py +++ b/icefall/rnn_lm/check-onnx-streaming.py @@ -103,17 +103,14 @@ def main(): x = torch.randint( low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 ) - y = torch.randint( - low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 - ) h0 = torch.rand(num_layers, n, hidden_size) c0 = torch.rand(num_layers, n, hidden_size) - torch_nll, torch_h0, torch_c0 = torch_model.streaming_forward(x, y, h0, c0) - onnx_nll, onnx_h0, onnx_c0 = onnx_model(x, y, h0, c0) + torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0) + onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0) for torch_v, onnx_v in zip( - (torch_nll, torch_h0, torch_c0), (onnx_nll, onnx_h0, onnx_c0) + (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0) ): assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py index 1d9af5e3d..dfede708b 100755 --- a/icefall/rnn_lm/export-onnx.py +++ b/icefall/rnn_lm/export-onnx.py @@ -235,7 +235,6 @@ def export_with_state( embedding_dim = model.embedding_dim x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) - y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) h0 = torch.zeros(num_layers, N, hidden_size) c0 = torch.zeros(num_layers, N, hidden_size) @@ -252,18 +251,17 @@ def export_with_state( torch.onnx.export( model, - (x, y, h0, c0), + (x, h0, c0), filename, verbose=False, opset_version=opset_version, - input_names=["x", "y", "h0", "c0"], - output_names=["nll", "next_h0", "next_c0"], + input_names=["x", "h0", "c0"], + output_names=["log_softmax", "next_h0", "next_c0"], dynamic_axes={ "x": {0: "N", 1: "L"}, - "y": {0: "N", 1: "L"}, "h0": {1: "N"}, "c0": {1: "N"}, - "nll": {0: "N"}, + "log_softmax": {0: "N"}, "next_h0": {1: "N"}, "next_c0": {1: "N"}, }, @@ -372,7 +370,7 @@ def main(): # now for streaming export saved_forward = model.__class__.forward - model.__class__.forward = model.__class__.streaming_forward + model.__class__.forward = model.__class__.score_token_onnx streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx" export_with_state( model=model, diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index be4e7f8c5..dadf23009 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -182,8 +182,8 @@ def main(): if params.jit: logging.info("Using torch.jit.script") - model.__class__.streaming_forward = torch.jit.export( - model.__class__.streaming_forward + model.__class__.score_token_onnx = torch.jit.export( + model.__class__.score_token_onnx ) model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index a8eaadc0c..5eacf5d40 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -234,6 +234,33 @@ class RnnLmModel(torch.nn.Module): return logits[:, 0].log_softmax(-1), 
states
 
+    def score_token_onnx(
+        self,
+        x: torch.Tensor,
+        state_h: torch.Tensor,
+        state_c: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Score a batch of tokens, i.e., each sample in the batch should be
+        a single token. For example, x = torch.tensor([[5], [10], [20]]).
+
+        Args:
+          x:
+            A 2-D tensor of shape (N, 1); each row contains a single token.
+          state_h:
+            The hidden state h of the RNN, of shape (num_layers, N, hidden_dim).
+          state_c:
+            The cell state c of the RNN, of shape (num_layers, N, hidden_dim).
+
+        Returns:
+          A tuple (log_probs, next_h0, next_c0): the log-softmax over the
+          vocabulary, of shape (N, vocab_size), and the updated RNN states.
+        """
+        embedding = self.input_embedding(x)
+        rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c))
+        logits = self.output_linear(rnn_out)
+
+        return logits[:, 0].log_softmax(-1), next_h0, next_c0
+
     def forward_with_state(
         self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
     ):
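
For reference, a minimal sketch of how the exported streaming model might be driven from onnxruntime. The input and output names ("x", "h0", "c0" -> "log_softmax", "next_h0", "next_c0") match the torch.onnx.export call in this patch; the file name and the model dimensions below are hypothetical placeholders that depend on the actual export configuration.

    # Sketch: step-by-step token scoring with the exported streaming RNN LM.
    # Assumes a model exported by export-onnx.py; the file name and the
    # num_layers/hidden_size values here are illustrative placeholders.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("with-state-epoch-99-avg-1.onnx")

    num_layers, batch_size, hidden_size = 2, 3, 512  # must match the trained model

    h0 = np.zeros((num_layers, batch_size, hidden_size), dtype=np.float32)
    c0 = np.zeros((num_layers, batch_size, hidden_size), dtype=np.float32)

    # One token per sample, shape (N, 1), as required by score_token_onnx.
    x = np.array([[5], [10], [20]], dtype=np.int64)

    log_probs, h0, c0 = session.run(
        ["log_softmax", "next_h0", "next_c0"],
        {"x": x, "h0": h0, "c0": c0},
    )
    # log_probs has shape (N, vocab_size); feed the chosen next tokens back
    # in with the updated states to continue scoring.
    print(log_probs.shape)

This mirrors what check-onnx-streaming.py verifies: the ONNX outputs should agree with torch_model.score_token_onnx(x, h0, c0) to within atol=1e-5.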