export score_token interface for onnx-runtime (#1050)

Repository: https://github.com/k2-fsa/icefall.git
Commit: 44d016e4a7 (parent: 6c326427a0)
@@ -103,17 +103,14 @@ def main():
     x = torch.randint(
         low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
     )
-    y = torch.randint(
-        low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
-    )
     h0 = torch.rand(num_layers, n, hidden_size)
     c0 = torch.rand(num_layers, n, hidden_size)

-    torch_nll, torch_h0, torch_c0 = torch_model.streaming_forward(x, y, h0, c0)
-    onnx_nll, onnx_h0, onnx_c0 = onnx_model(x, y, h0, c0)
+    torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0)
+    onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0)

     for torch_v, onnx_v in zip(
-        (torch_nll, torch_h0, torch_c0), (onnx_nll, onnx_h0, onnx_c0)
+        (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
     ):

         assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
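The separate y input disappears because the new interface returns log-probabilities over the whole vocabulary rather than an nll for a given target, so the caller no longer needs to pass the next token in. Below is a hedged sketch of driving the exported graph directly from onnxruntime; the file name and the sizes are assumptions, while the input/output names come from the export call further down.

# Sketch: one scoring step through onnxruntime (file name and sizes assumed).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("with-state.onnx", providers=["CPUExecutionProvider"])

n, num_layers, hidden_size, vocab_size = 3, 2, 512, 500   # illustrative only
x = np.random.randint(1, vocab_size, size=(n, 1), dtype=np.int64)
h0 = np.zeros((num_layers, n, hidden_size), dtype=np.float32)
c0 = np.zeros((num_layers, n, hidden_size), dtype=np.float32)

log_prob, next_h0, next_c0 = session.run(
    ["log_softmax", "next_h0", "next_c0"],   # output names declared at export time
    {"x": x, "h0": h0, "c0": c0},            # note: no "y" in the new signature
)
print(log_prob.shape)                        # (n, vocab_size) log-probabilities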
@@ -235,7 +235,6 @@ def export_with_state(
     embedding_dim = model.embedding_dim

     x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
-    y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
     h0 = torch.zeros(num_layers, N, hidden_size)
     c0 = torch.zeros(num_layers, N, hidden_size)

@@ -252,18 +251,17 @@ def export_with_state(

     torch.onnx.export(
         model,
-        (x, y, h0, c0),
+        (x, h0, c0),
         filename,
         verbose=False,
         opset_version=opset_version,
-        input_names=["x", "y", "h0", "c0"],
-        output_names=["nll", "next_h0", "next_c0"],
+        input_names=["x", "h0", "c0"],
+        output_names=["log_softmax", "next_h0", "next_c0"],
         dynamic_axes={
             "x": {0: "N", 1: "L"},
-            "y": {0: "N", 1: "L"},
             "h0": {1: "N"},
             "c0": {1: "N"},
-            "nll": {0: "N"},
+            "log_softmax": {0: "N"},
             "next_h0": {1: "N"},
             "next_c0": {1: "N"},
         },
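Because the graph returns its next states as outputs, a caller can score a hypothesis token by token, feeding next_h0/next_c0 back in as h0/c0 on each step. A minimal sketch under the same assumptions as above (the token ids and dimensions are invented):

# Sketch: accumulating an LM score token by token with carried RNN state.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("with-state.onnx", providers=["CPUExecutionProvider"])

num_layers, hidden_size = 2, 512                 # assumed; must match the model
h = np.zeros((num_layers, 1, hidden_size), dtype=np.float32)
c = np.zeros((num_layers, 1, hidden_size), dtype=np.float32)

tokens = [5, 10, 20]                             # an arbitrary example hypothesis
total = 0.0
for prev, cur in zip(tokens[:-1], tokens[1:]):
    x = np.array([[prev]], dtype=np.int64)       # shape (N=1, L=1): one token
    log_prob, h, c = session.run(
        ["log_softmax", "next_h0", "next_c0"], {"x": x, "h0": h, "c0": c}
    )
    total += float(log_prob[0, cur])             # log P(cur | prefix)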
@@ -372,7 +370,7 @@ def main():

     # now for streaming export
     saved_forward = model.__class__.forward
-    model.__class__.forward = model.__class__.streaming_forward
+    model.__class__.forward = model.__class__.score_token_onnx
     streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx"
     export_with_state(
         model=model,
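The rebinding above exists because torch.onnx.export traces a module through its forward method only, so the method we actually want in the graph has to sit there for the duration of the export. A self-contained toy of the save/patch/restore pattern (TinyLM is invented for illustration, not icefall code):

import torch

class TinyLM(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1          # training-time entry point

    def score_token_onnx(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2          # the entry point we actually want to export

model = TinyLM()
saved_forward = TinyLM.forward
TinyLM.forward = TinyLM.score_token_onnx   # export traces forward(), so rebind it
torch.onnx.export(model, (torch.zeros(2, 3),), "tiny.onnx", opset_version=13)
TinyLM.forward = saved_forward             # restore the original forward afterwards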
@@ -182,8 +182,8 @@ def main():
     if params.jit:
         logging.info("Using torch.jit.script")

-        model.__class__.streaming_forward = torch.jit.export(
-            model.__class__.streaming_forward
+        model.__class__.score_token_onnx = torch.jit.export(
+            model.__class__.score_token_onnx
         )
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
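torch.jit.script compiles forward plus whatever it calls; any extra entry point must be marked with torch.jit.export to survive scripting, which is what the assignment above does without editing the class definition. The decorator spelling of the same idea, on an invented toy class:

import torch

class TinyLM(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    @torch.jit.export                      # keep this method in the scripted module
    def score_token_onnx(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

scripted = torch.jit.script(TinyLM())
print(scripted.score_token_onnx(torch.ones(2)))  # callable on the TorchScript module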
@@ -234,6 +234,33 @@ class RnnLmModel(torch.nn.Module):

         return logits[:, 0].log_softmax(-1), states

+    def score_token_onnx(
+        self,
+        x: torch.Tensor,
+        state_h: torch.Tensor,
+        state_c: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Score a batch of tokens, i.e., each sample in the batch should be
+        a single token. For example, x = torch.tensor([[5], [10], [20]])
+
+        Args:
+          x (torch.Tensor):
+            A batch of tokens
+          state_h:
+            state h of the RNN, with shape (num_layers, bs, hidden_dim)
+          state_c:
+            state c of the RNN, with shape (num_layers, bs, hidden_dim)
+
+        Returns:
+          A tuple (log_prob, next_h0, next_c0), where log_prob has shape
+          (bs, vocab_size) and next_h0/next_c0 are the updated RNN states.
+        """
+        embedding = self.input_embedding(x)
+        rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c))
+        logits = self.output_linear(rnn_out)
+
+        return logits[:, 0].log_softmax(-1), next_h0, next_c0
+
     def forward_with_state(
         self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
     ):
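Finally, a hedged sketch of calling the new method in plain PyTorch before any export; the RnnLmModel constructor arguments and import path below are assumptions about the recipe, not taken from this diff:

# Sketch only: constructor args and import path are assumed, not from the diff.
import torch
from model import RnnLmModel   # path assumed; adjust to the recipe's layout

num_layers, hidden_dim = 2, 512
lm = RnnLmModel(
    vocab_size=500, embedding_dim=256, hidden_dim=hidden_dim,
    num_layers=num_layers, tie_weights=False,
)

x = torch.tensor([[5], [10], [20]])                  # one token per batch sample
h0 = torch.zeros(num_layers, x.size(0), hidden_dim)
c0 = torch.zeros(num_layers, x.size(0), hidden_dim)

log_prob, h1, c1 = lm.score_token_onnx(x, h0, c0)    # log_prob: (bs, vocab_size)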