Export stateless pruned transducer to jit script module

This commit is contained in:
pkufool 2022-04-06 10:31:28 +08:00
parent cb3ba16f2b
commit d5437ba623
2 changed files with 54 additions and 2 deletions

View File

@ -116,8 +116,6 @@ def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))

View File

@ -15,6 +15,8 @@
# limitations under the License. # limitations under the License.
from typing import Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -58,6 +60,12 @@ class Transducer(nn.Module):
self.decoder = decoder self.decoder = decoder
self.joiner = joiner self.joiner = joiner
self.subsampling_factor = self.encoder.subsampling_factor
self.context_size = self.decoder.context_size
self.blank_id = self.decoder.blank_id
self.vocab_size = self.decoder.vocab_size
@torch.jit.ignore
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
@ -167,3 +175,49 @@ class Transducer(nn.Module):
) )
return (simple_loss, pruned_loss) return (simple_loss, pruned_loss)
@torch.jit.export
def encoder_forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
return self.encoder(x, x_lens)
@torch.jit.export
def decoder_forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U) with blank prepended.
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
return self.decoder(y, False)
@torch.jit.export
def joiner_forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, 1, C).
decoder_out:
Output from the decoder. Its shape is (N, 1, C).
Returns:
Return a tensor of shape (N, C).
"""
logits = self.joiner(encoder_out.unsqueeze(2), decoder_out.unsqueeze(1))
return logits.squeeze(1).squeeze(1)