Convert make_pad_mask

This commit is contained in:
Fangjun Kuang 2022-08-10 16:05:31 +08:00
parent 317f47fb37
commit 89124a59ac
3 changed files with 88 additions and 2 deletions

View File

@ -35,6 +35,12 @@ from torch import Tensor, nn
from icefall.utils import make_pad_mask, subsequent_chunk_mask
class MakePadMask(nn.Module):
def forward(self, lengths: Tensor) -> Tensor:
"""See doc for :func:`make_pad_mask`"""
return make_pad_mask(lengths)
class Conformer(EncoderInterface):
"""
Args:
@ -111,6 +117,7 @@ class Conformer(EncoderInterface):
self.num_left_chunks = num_left_chunks
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
self.make_pad_mask = MakePadMask()
encoder_layer = ConformerEncoderLayer(
d_model,
@ -158,7 +165,7 @@ class Conformer(EncoderInterface):
if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()
src_key_padding_mask = make_pad_mask(lengths)
src_key_padding_mask = self.make_pad_mask(lengths)
if self.dynamic_chunk_training:
assert (
@ -835,7 +842,7 @@ class RelPositionalEncoding(torch.nn.Module):
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.register_buffer("pe", pe.to(device=x.device, dtype=x.dtype))
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(
self,

View File

@ -0,0 +1,42 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
from conformer import MakePadMask
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.make_pad_mask = MakePadMask()
def forward(self, x: torch.Tensor):
"""
Args:
x:
(N,)
"""
src_key_padding_mask = self.make_pad_mask(x)
return src_key_padding_mask
def generate_pt():
f = Foo()
f.eval()
x = torch.tensor([1, 3, 5])
y = f(x)
print("y.shape", y.shape)
print(y)
m = torch.jit.trace(f, x)
m.save("foo/make_pad_mask.pt")
print(m.graph)
def main():
generate_pt()
if __name__ == "__main__":
torch.manual_seed(20220809)
main()

View File

@ -0,0 +1,37 @@
#!/usr/bin/env python3
import ncnn
import torch
@torch.no_grad()
def main():
x = torch.tensor([1, 3, 5, 8])
m = torch.jit.load("foo/make_pad_mask.pt")
t = m(x)
print(t.shape)
print(t)
param = "foo/make_pad_mask.ncnn.param"
model = "foo/make_pad_mask.ncnn.bin"
with ncnn.Net() as net:
net.load_param(param)
net.load_model(model)
with net.create_extractor() as ex:
x = x.to(torch.int32)
ex.input("in0", ncnn.Mat(x.numpy()).clone())
ret, ncnn_out0 = ex.extract("out0")
assert ret == 0, ret
n = ncnn_out0.numpy("i")
print(n.shape)
n = torch.from_numpy(n).to(torch.bool)
print(n)
assert torch.equal(t, n), (t, n)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
torch.manual_seed(202208010)
main()