mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
export score_token interface for onnx-runtime (#1050)
This commit is contained in:
parent
6c326427a0
commit
44d016e4a7
@ -103,17 +103,14 @@ def main():
|
||||
x = torch.randint(
|
||||
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
|
||||
)
|
||||
y = torch.randint(
|
||||
low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64
|
||||
)
|
||||
h0 = torch.rand(num_layers, n, hidden_size)
|
||||
c0 = torch.rand(num_layers, n, hidden_size)
|
||||
|
||||
torch_nll, torch_h0, torch_c0 = torch_model.streaming_forward(x, y, h0, c0)
|
||||
onnx_nll, onnx_h0, onnx_c0 = onnx_model(x, y, h0, c0)
|
||||
torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0)
|
||||
onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0)
|
||||
|
||||
for torch_v, onnx_v in zip(
|
||||
(torch_nll, torch_h0, torch_c0), (onnx_nll, onnx_h0, onnx_c0)
|
||||
(torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
|
||||
):
|
||||
|
||||
assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
|
||||
|
@ -235,7 +235,6 @@ def export_with_state(
|
||||
embedding_dim = model.embedding_dim
|
||||
|
||||
x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
|
||||
y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64)
|
||||
h0 = torch.zeros(num_layers, N, hidden_size)
|
||||
c0 = torch.zeros(num_layers, N, hidden_size)
|
||||
|
||||
@ -252,18 +251,17 @@ def export_with_state(
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(x, y, h0, c0),
|
||||
(x, h0, c0),
|
||||
filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "y", "h0", "c0"],
|
||||
output_names=["nll", "next_h0", "next_c0"],
|
||||
input_names=["x", "h0", "c0"],
|
||||
output_names=["log_softmax", "next_h0", "next_c0"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "L"},
|
||||
"y": {0: "N", 1: "L"},
|
||||
"h0": {1: "N"},
|
||||
"c0": {1: "N"},
|
||||
"nll": {0: "N"},
|
||||
"log_softmax": {0: "N"},
|
||||
"next_h0": {1: "N"},
|
||||
"next_c0": {1: "N"},
|
||||
},
|
||||
@ -372,7 +370,7 @@ def main():
|
||||
|
||||
# now for streaming export
|
||||
saved_forward = model.__class__.forward
|
||||
model.__class__.forward = model.__class__.streaming_forward
|
||||
model.__class__.forward = model.__class__.score_token_onnx
|
||||
streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx"
|
||||
export_with_state(
|
||||
model=model,
|
||||
|
@ -182,8 +182,8 @@ def main():
|
||||
if params.jit:
|
||||
logging.info("Using torch.jit.script")
|
||||
|
||||
model.__class__.streaming_forward = torch.jit.export(
|
||||
model.__class__.streaming_forward
|
||||
model.__class__.score_token_onnx = torch.jit.export(
|
||||
model.__class__.score_token_onnx
|
||||
)
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
|
@ -234,6 +234,33 @@ class RnnLmModel(torch.nn.Module):
|
||||
|
||||
return logits[:, 0].log_softmax(-1), states
|
||||
|
||||
def score_token_onnx(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
state_h: torch.Tensor,
|
||||
state_c: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Score a batch of tokens, i.e., each sample in the batch should be a
|
||||
single token. For example, x = torch.tensor([[5],[10],[20]])
|
||||
|
||||
|
||||
Args:
|
||||
x (torch.Tensor):
|
||||
A batch of tokens
|
||||
state_h:
|
||||
state h of RNN has the shape of (num_layers, bs, hidden_dim)
|
||||
state_c:
|
||||
state c of RNN has the shape of (num_layers, bs, hidden_dim)
|
||||
|
||||
Returns:
|
||||
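Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: the log-softmax (over the last dim) of the logits at index 0 of dim 1, followed by the next h and c states returned by the RNN.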
|
||||
"""
|
||||
embedding = self.input_embedding(x)
|
||||
rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c))
|
||||
logits = self.output_linear(rnn_out)
|
||||
|
||||
return logits[:, 0].log_softmax(-1), next_h0, next_c0
|
||||
|
||||
def forward_with_state(
|
||||
self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
|
||||
):
|
||||
|
Loading…
x
Reference in New Issue
Block a user