mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
Export stateless pruned transducer to jit script module
This commit is contained in:
parent
cb3ba16f2b
commit
d5437ba623
@ -116,8 +116,6 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
assert args.jit is False, "Support torchscript will be added later"
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
|
@ -15,6 +15,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -58,6 +60,12 @@ class Transducer(nn.Module):
|
|||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.joiner = joiner
|
self.joiner = joiner
|
||||||
|
|
||||||
|
self.subsampling_factor = self.encoder.subsampling_factor
|
||||||
|
self.context_size = self.decoder.context_size
|
||||||
|
self.blank_id = self.decoder.blank_id
|
||||||
|
self.vocab_size = self.decoder.vocab_size
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -167,3 +175,49 @@ class Transducer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return (simple_loss, pruned_loss)
|
return (simple_loss, pruned_loss)
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def encoder_forward(
|
||||||
|
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||||
|
x_lens:
|
||||||
|
A tensor of shape (batch_size,) containing the number of frames in
|
||||||
|
`x` before padding.
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing 2 tensors:
|
||||||
|
- logits, its shape is (batch_size, output_seq_len, output_dim)
|
||||||
|
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||||
|
of frames in `logits` before padding.
|
||||||
|
"""
|
||||||
|
return self.encoder(x, x_lens)
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def decoder_forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, U) with blank prepended.
|
||||||
|
Returns:
|
||||||
|
Return a tensor of shape (N, U, embedding_dim).
|
||||||
|
"""
|
||||||
|
return self.decoder(y, False)
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def joiner_forward(
|
||||||
|
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, 1, C).
|
||||||
|
decoder_out:
|
||||||
|
Output from the decoder. Its shape is (N, 1, C).
|
||||||
|
Returns:
|
||||||
|
Return a tensor of shape (N, C).
|
||||||
|
"""
|
||||||
|
logits = self.joiner(encoder_out.unsqueeze(2), decoder_out.unsqueeze(1))
|
||||||
|
return logits.squeeze(1).squeeze(1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user