From d5437ba62355eafe673c45f5c40052f3ac30f946 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Wed, 6 Apr 2022 10:31:28 +0800
Subject: [PATCH] Export stateless pruned transducer to jit script module

---
 .../ASR/pruned_transducer_stateless/export.py |  2 -
 .../ASR/pruned_transducer_stateless/model.py  | 54 +++++++++++++++++++
 2 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
index 7d2a07817..3f054fac9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py
@@ -116,8 +116,6 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
 
-    assert args.jit is False, "Support torchscript will be added later"
-
     params = get_params()
     params.update(vars(args))
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py
index 2f019bcdb..7992c1de3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 
 
+from typing import Tuple
+
 import k2
 import torch
 import torch.nn as nn
@@ -58,6 +60,12 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner
 
+        self.subsampling_factor = self.encoder.subsampling_factor
+        self.context_size = self.decoder.context_size
+        self.blank_id = self.decoder.blank_id
+        self.vocab_size = self.decoder.vocab_size
+
+    @torch.jit.ignore
     def forward(
         self,
         x: torch.Tensor,
@@ -167,3 +175,49 @@ class Transducer(nn.Module):
         )
 
         return (simple_loss, pruned_loss)
+
+    @torch.jit.export
+    def encoder_forward(
+        self, x: torch.Tensor, x_lens: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+        Returns:
+          Return a tuple containing 2 tensors:
+            - logits, its shape is (batch_size, output_seq_len, output_dim)
+            - logit_lens, a tensor of shape (batch_size,) containing the number
+              of frames in `logits` before padding.
+        """
+        return self.encoder(x, x_lens)
+
+    @torch.jit.export
+    def decoder_forward(self, y: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          y:
+            A 2-D tensor of shape (N, U) with blank prepended.
+        Returns:
+          Return a tensor of shape (N, U, embedding_dim).
+        """
+        return self.decoder(y, False)
+
+    @torch.jit.export
+    def joiner_forward(
+        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Args:
+          encoder_out:
+            Output from the encoder. Its shape is (N, 1, C).
+          decoder_out:
+            Output from the decoder. Its shape is (N, 1, C).
+        Returns:
+          Return a tensor of shape (N, C).
+        """
+        logits = self.joiner(encoder_out.unsqueeze(2), decoder_out.unsqueeze(1))
+        return logits.squeeze(1).squeeze(1)
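
Note for reviewers (not part of the patch): below is a minimal sketch of how
the exported script module could be driven from Python. The checkpoint path
"cpu_jit.pt", the 80-dim fbank feature shapes, and the single greedy-search
step are illustrative assumptions; only encoder_forward / decoder_forward /
joiner_forward and the blank_id / context_size attributes come from the diff
above.

import torch

# Hypothetical path; substitute whatever file export.py wrote when run
# with --jit enabled.
model = torch.jit.load("pruned_transducer_stateless/exp/cpu_jit.pt")
model.eval()

# Illustrative input: batch of 1, 100 frames of 80-dim fbank features.
x = torch.randn(1, 100, 80)
x_lens = torch.tensor([100])

with torch.no_grad():
    # forward() is marked @torch.jit.ignore, so only the three exported
    # methods are callable on the scripted module.
    encoder_out, encoder_out_lens = model.encoder_forward(x, x_lens)

    # One greedy-search-style step: the stateless decoder consumes exactly
    # `context_size` previous tokens (blanks at the start of decoding) and,
    # with need_pad=False, emits a single output position.
    context = torch.tensor([[model.blank_id] * model.context_size])
    decoder_out = model.decoder_forward(context)  # (1, 1, C)

    # Join the first encoder frame with the decoder output.
    logits = model.joiner_forward(encoder_out[:, :1, :], decoder_out)  # (1, vocab_size)
    next_token = logits.argmax(dim=-1).item()

Splitting the export into per-submodule entry points matches how the
transducer is actually decoded: forward() computes the training losses via
k2 (which TorchScript cannot compile, hence @torch.jit.ignore), while
greedy/beam search drives the encoder, decoder, and joiner separately.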