mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Export stateless pruned transducer to jit script module
This commit is contained in:
parent
cb3ba16f2b
commit
d5437ba623
@ -116,8 +116,6 @@ def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
assert args.jit is False, "Support torchscript will be added later"
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
|
@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -58,6 +60,12 @@ class Transducer(nn.Module):
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.subsampling_factor = self.encoder.subsampling_factor
|
||||
self.context_size = self.decoder.context_size
|
||||
self.blank_id = self.decoder.blank_id
|
||||
self.vocab_size = self.decoder.vocab_size
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -167,3 +175,49 @@ class Transducer(nn.Module):
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
||||
@torch.jit.export
|
||||
def encoder_forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing 2 tensors:
|
||||
- logits, its shape is (batch_size, output_seq_len, output_dim)
|
||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `logits` before padding.
|
||||
"""
|
||||
return self.encoder(x, x_lens)
|
||||
|
||||
@torch.jit.export
|
||||
def decoder_forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U) with blank prepended.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, embedding_dim).
|
||||
"""
|
||||
return self.decoder(y, False)
|
||||
|
||||
@torch.jit.export
|
||||
def joiner_forward(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, 1, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, 1, C).
|
||||
Returns:
|
||||
Return a tensor of shape (N, C).
|
||||
"""
|
||||
logits = self.joiner(encoder_out.unsqueeze(2), decoder_out.unsqueeze(1))
|
||||
return logits.squeeze(1).squeeze(1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user