export score_token interface for onnx-runtime (#1050)

This commit is contained in:
PF Luo 2023-05-10 22:41:07 +08:00 committed by GitHub
parent 6c326427a0
commit 44d016e4a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 15 deletions

View File

@ -103,17 +103,14 @@ def main():
x = torch.randint( x = torch.randint(
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
) )
y = torch.randint(
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
)
h0 = torch.rand(num_layers, n, hidden_size) h0 = torch.rand(num_layers, n, hidden_size)
c0 = torch.rand(num_layers, n, hidden_size) c0 = torch.rand(num_layers, n, hidden_size)
torch_nll, torch_h0, torch_c0 = torch_model.streaming_forward(x, y, h0, c0) torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0)
onnx_nll, onnx_h0, onnx_c0 = onnx_model(x, y, h0, c0) onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0)
for torch_v, onnx_v in zip( for torch_v, onnx_v in zip(
(torch_nll, torch_h0, torch_c0), (onnx_nll, onnx_h0, onnx_c0) (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
): ):
assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( assert torch.allclose(torch_v, onnx_v, atol=1e-5), (

View File

@ -235,7 +235,6 @@ def export_with_state(
embedding_dim = model.embedding_dim embedding_dim = model.embedding_dim
x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
h0 = torch.zeros(num_layers, N, hidden_size) h0 = torch.zeros(num_layers, N, hidden_size)
c0 = torch.zeros(num_layers, N, hidden_size) c0 = torch.zeros(num_layers, N, hidden_size)
@ -252,18 +251,17 @@ def export_with_state(
torch.onnx.export( torch.onnx.export(
model, model,
(x, y, h0, c0), (x, h0, c0),
filename, filename,
verbose=False, verbose=False,
opset_version=opset_version, opset_version=opset_version,
input_names=["x", "y", "h0", "c0"], input_names=["x", "h0", "c0"],
output_names=["nll", "next_h0", "next_c0"], output_names=["log_softmax", "next_h0", "next_c0"],
dynamic_axes={ dynamic_axes={
"x": {0: "N", 1: "L"}, "x": {0: "N", 1: "L"},
"y": {0: "N", 1: "L"},
"h0": {1: "N"}, "h0": {1: "N"},
"c0": {1: "N"}, "c0": {1: "N"},
"nll": {0: "N"}, "log_softmax": {0: "N"},
"next_h0": {1: "N"}, "next_h0": {1: "N"},
"next_c0": {1: "N"}, "next_c0": {1: "N"},
}, },
@ -372,7 +370,7 @@ def main():
# now for streaming export # now for streaming export
saved_forward = model.__class__.forward saved_forward = model.__class__.forward
model.__class__.forward = model.__class__.streaming_forward model.__class__.forward = model.__class__.score_token_onnx
streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx" streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx"
export_with_state( export_with_state(
model=model, model=model,

View File

@ -182,8 +182,8 @@ def main():
if params.jit: if params.jit:
logging.info("Using torch.jit.script") logging.info("Using torch.jit.script")
model.__class__.streaming_forward = torch.jit.export( model.__class__.score_token_onnx = torch.jit.export(
model.__class__.streaming_forward model.__class__.score_token_onnx
) )
model = torch.jit.script(model) model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt" filename = params.exp_dir / "cpu_jit.pt"

View File

@ -234,6 +234,33 @@ class RnnLmModel(torch.nn.Module):
return logits[:, 0].log_softmax(-1), states return logits[:, 0].log_softmax(-1), states
def score_token_onnx(
    self,
    x: torch.Tensor,
    state_h: torch.Tensor,
    state_c: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score a batch of tokens, i.e., each sample in the batch should be a
    single token. For example, x = torch.tensor([[5],[10],[20]])

    The RNN state is taken and returned as two separate tensors instead
    of a single (h, c) tuple.

    Args:
      x (torch.Tensor):
        A batch of tokens. Presumably of shape (bs, 1) with the batch
        dimension first, as in the example above -- confirm against how
        self.rnn is constructed.
      state_h:
        state h of RNN has the shape of (num_layers, bs, hidden_dim)
      state_c:
        state c of RNN has the shape of (num_layers, bs, hidden_dim)

    Returns:
      A tuple of three tensors:

        - the log-softmax (over the last dimension) of the output of
          self.output_linear, taken at index 0 of dimension 1;
        - the next state h returned by self.rnn;
        - the next state c returned by self.rnn.
    """
    embedding = self.input_embedding(x)
    rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c))
    logits = self.output_linear(rnn_out)

    # Only index 0 along dimension 1 is scored, which matches the
    # "one token per sample" contract stated above. The returned states
    # are whatever self.rnn produced for the whole input.
    return logits[:, 0].log_softmax(-1), next_h0, next_c0
def forward_with_state( def forward_with_state(
self, tokens, token_lens, sos_id, eos_id, blank_id, state=None self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
): ):