diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 8db9054a5..625ea10d6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -35,6 +35,12 @@ from torch import Tensor, nn from icefall.utils import make_pad_mask, subsequent_chunk_mask +class MakePadMask(nn.Module): + def forward(self, lengths: Tensor) -> Tensor: + """See doc for :func:`make_pad_mask`""" + return make_pad_mask(lengths) + + class Conformer(EncoderInterface): """ Args: @@ -111,6 +117,7 @@ class Conformer(EncoderInterface): self.num_left_chunks = num_left_chunks self.encoder_pos = RelPositionalEncoding(d_model, dropout) + self.make_pad_mask = MakePadMask() encoder_layer = ConformerEncoderLayer( d_model, @@ -158,7 +165,7 @@ class Conformer(EncoderInterface): if not torch.jit.is_tracing(): assert x.size(0) == lengths.max().item() - src_key_padding_mask = make_pad_mask(lengths) + src_key_padding_mask = self.make_pad_mask(lengths) if self.dynamic_chunk_training: assert ( @@ -835,7 +842,7 @@ class RelPositionalEncoding(torch.nn.Module): pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) pe_negative = pe_negative[1:].unsqueeze(0) pe = torch.cat([pe_positive, pe_negative], dim=1) - self.register_buffer("pe", pe.to(device=x.device, dtype=x.dtype)) + self.pe = pe.to(device=x.device, dtype=x.dtype) def forward( self, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_generate_make_pad_mask.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_generate_make_pad_mask.py new file mode 100755 index 000000000..83ce3dd1d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_generate_make_pad_mask.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + + +import torch +import torch.nn as nn +from conformer import MakePadMask + + +class Foo(nn.Module): + def __init__(self): + super().__init__() + self.make_pad_mask = MakePadMask() + + def forward(self, x: torch.Tensor): + """ + Args: + x: + (N,) + """ + src_key_padding_mask = self.make_pad_mask(x) + return src_key_padding_mask + + +def generate_pt(): + f = Foo() + f.eval() + x = torch.tensor([1, 3, 5]) + y = f(x) + print("y.shape", y.shape) + print(y) + m = torch.jit.trace(f, x) + m.save("foo/make_pad_mask.pt") + print(m.graph) + + +def main(): + generate_pt() + + +if __name__ == "__main__": + torch.manual_seed(20220809) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_make_pad_mask.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_make_pad_mask.py new file mode 100755 index 000000000..ff321e730 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_make_pad_mask.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +import ncnn +import torch + + +@torch.no_grad() +def main(): + x = torch.tensor([1, 3, 5, 8]) + m = torch.jit.load("foo/make_pad_mask.pt") + t = m(x) + print(t.shape) + print(t) + + param = "foo/make_pad_mask.ncnn.param" + model = "foo/make_pad_mask.ncnn.bin" + with ncnn.Net() as net: + net.load_param(param) + net.load_model(model) + with net.create_extractor() as ex: + x = x.to(torch.int32) + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + n = ncnn_out0.numpy("i") + print(n.shape) + n = torch.from_numpy(n).to(torch.bool) + print(n) + assert torch.equal(t, n), (t, n) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +if __name__ == "__main__": + torch.manual_seed(202208010) + main()