export score_token interface for onnx-runtime (#1050)

PF Luo 2023-05-10 22:41:07 +08:00 committed by GitHub
parent 6c326427a0
commit 44d016e4a7
4 changed files with 37 additions and 15 deletions

View File

@@ -103,17 +103,14 @@ def main():
     x = torch.randint(
         low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
     )
-    y = torch.randint(
-        low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
-    )
     h0 = torch.rand(num_layers, n, hidden_size)
     c0 = torch.rand(num_layers, n, hidden_size)
-    torch_nll, torch_h0, torch_c0 = torch_model.streaming_forward(x, y, h0, c0)
-    onnx_nll, onnx_h0, onnx_c0 = onnx_model(x, y, h0, c0)
+    torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0)
+    onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0)
     for torch_v, onnx_v in zip(
-        (torch_nll, torch_h0, torch_c0), (onnx_nll, onnx_h0, onnx_c0)
+        (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
     ):
         assert torch.allclose(torch_v, onnx_v, atol=1e-5), (

View File

@@ -235,7 +235,6 @@ def export_with_state(
     embedding_dim = model.embedding_dim
     x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
-    y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
     h0 = torch.zeros(num_layers, N, hidden_size)
     c0 = torch.zeros(num_layers, N, hidden_size)
@@ -252,18 +251,17 @@ def export_with_state(
     torch.onnx.export(
         model,
-        (x, y, h0, c0),
+        (x, h0, c0),
         filename,
         verbose=False,
         opset_version=opset_version,
-        input_names=["x", "y", "h0", "c0"],
-        output_names=["nll", "next_h0", "next_c0"],
+        input_names=["x", "h0", "c0"],
+        output_names=["log_softmax", "next_h0", "next_c0"],
         dynamic_axes={
             "x": {0: "N", 1: "L"},
-            "y": {0: "N", 1: "L"},
             "h0": {1: "N"},
             "c0": {1: "N"},
-            "nll": {0: "N"},
+            "log_softmax": {0: "N"},
             "next_h0": {1: "N"},
             "next_c0": {1: "N"},
         },
@@ -372,7 +370,7 @@ def main():
     # now for streaming export
     saved_forward = model.__class__.forward
-    model.__class__.forward = model.__class__.streaming_forward
+    model.__class__.forward = model.__class__.score_token_onnx
     streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx"
     export_with_state(
         model=model,
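
With y removed, the exported graph takes a single token per sample plus the two RNN state tensors (h and c). Below is a minimal sketch of driving the exported file with onnxruntime; only the input and output names come from the export call above, while the file name and the num_layers / hidden_size values are placeholders:

    import numpy as np
    import onnxruntime as ort

    # Placeholders; use the actual exported file and the model's real dims.
    num_layers, hidden_size = 2, 512
    session = ort.InferenceSession(
        "with-state.onnx", providers=["CPUExecutionProvider"]
    )

    x = np.array([[5], [10], [20]], dtype=np.int64)  # (N, L) = (3, 1)
    h0 = np.zeros((num_layers, 3, hidden_size), dtype=np.float32)
    c0 = np.zeros((num_layers, 3, hidden_size), dtype=np.float32)

    log_softmax, next_h0, next_c0 = session.run(
        ["log_softmax", "next_h0", "next_c0"],  # output names from the export
        {"x": x, "h0": h0, "c0": c0},  # input names from the export
    )
    print(log_softmax.shape)  # (3, vocab_size): next-token log-probabilities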

View File

@@ -182,8 +182,8 @@ def main():
     if params.jit:
         logging.info("Using torch.jit.script")
-        model.__class__.streaming_forward = torch.jit.export(
-            model.__class__.streaming_forward
+        model.__class__.score_token_onnx = torch.jit.export(
+            model.__class__.score_token_onnx
         )
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
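
Because score_token_onnx is marked with torch.jit.export, it is compiled into cpu_jit.pt alongside forward and stays callable on the loaded ScriptModule. A short usage sketch with placeholder dimensions:

    import torch

    # Placeholders; these must match the trained model's configuration.
    num_layers, hidden_size = 2, 512

    model = torch.jit.load("cpu_jit.pt")
    x = torch.tensor([[5], [10], [20]], dtype=torch.int64)  # one token per sample
    h0 = torch.zeros(num_layers, 3, hidden_size)
    c0 = torch.zeros(num_layers, 3, hidden_size)
    log_prob, h0, c0 = model.score_token_onnx(x, h0, c0)  # exported method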

View File

@@ -234,6 +234,33 @@ class RnnLmModel(torch.nn.Module):
         return logits[:, 0].log_softmax(-1), states
 
+    def score_token_onnx(
+        self,
+        x: torch.Tensor,
+        state_h: torch.Tensor,
+        state_c: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Score a batch of tokens, i.e., each sample in the batch should be
+        a single token. For example, x = torch.tensor([[5], [10], [20]]).
+
+        Args:
+          x (torch.Tensor):
+            A batch of tokens, of shape (N, 1).
+          state_h:
+            The h state of the RNN, of shape (num_layers, bs, hidden_dim).
+          state_c:
+            The c state of the RNN, of shape (num_layers, bs, hidden_dim).
+
+        Returns:
+          A tuple of three tensors: the log-softmax over the vocabulary for
+          the next token, of shape (N, vocab_size), and the next h and c
+          states of the RNN.
+        """
+        embedding = self.input_embedding(x)
+        rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c))
+        logits = self.output_linear(rnn_out)
+        return logits[:, 0].log_softmax(-1), next_h0, next_c0
+
     def forward_with_state(
         self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
     ):
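
Since score_token_onnx returns the updated states along with the log-probabilities, a sentence can be scored one token at a time by feeding the states back in. A minimal eager-mode sketch; model is assumed to be a constructed RnnLmModel, and sos_id and the dimensions are placeholders:

    import torch

    # Placeholders; take these from the actual model / lexicon.
    num_layers, hidden_size, sos_id = 2, 512, 1
    tokens = [5, 10, 20]  # one hypothetical sentence
    # model: an RnnLmModel instance (construction omitted).

    h = torch.zeros(num_layers, 1, hidden_size)
    c = torch.zeros(num_layers, 1, hidden_size)
    prev = torch.tensor([[sos_id]], dtype=torch.int64)
    total = 0.0
    for t in tokens:
        log_prob, h, c = model.score_token_onnx(prev, h, c)  # (1, vocab_size)
        total += log_prob[0, t].item()  # log P(t | history)
        prev = torch.tensor([[t]], dtype=torch.int64)
    print(total)  # sentence log-probability under the RNN LM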