mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Convert make_pad_mask
This commit is contained in:
parent
317f47fb37
commit
89124a59ac
@ -35,6 +35,12 @@ from torch import Tensor, nn
|
|||||||
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
||||||
|
|
||||||
|
|
||||||
|
class MakePadMask(nn.Module):
|
||||||
|
def forward(self, lengths: Tensor) -> Tensor:
|
||||||
|
"""See doc for :func:`make_pad_mask`"""
|
||||||
|
return make_pad_mask(lengths)
|
||||||
|
|
||||||
|
|
||||||
class Conformer(EncoderInterface):
|
class Conformer(EncoderInterface):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -111,6 +117,7 @@ class Conformer(EncoderInterface):
|
|||||||
self.num_left_chunks = num_left_chunks
|
self.num_left_chunks = num_left_chunks
|
||||||
|
|
||||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||||
|
self.make_pad_mask = MakePadMask()
|
||||||
|
|
||||||
encoder_layer = ConformerEncoderLayer(
|
encoder_layer = ConformerEncoderLayer(
|
||||||
d_model,
|
d_model,
|
||||||
@ -158,7 +165,7 @@ class Conformer(EncoderInterface):
|
|||||||
if not torch.jit.is_tracing():
|
if not torch.jit.is_tracing():
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(lengths)
|
src_key_padding_mask = self.make_pad_mask(lengths)
|
||||||
|
|
||||||
if self.dynamic_chunk_training:
|
if self.dynamic_chunk_training:
|
||||||
assert (
|
assert (
|
||||||
@ -835,7 +842,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||||
self.register_buffer("pe", pe.to(device=x.device, dtype=x.dtype))
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -0,0 +1,42 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from conformer import MakePadMask
|
||||||
|
|
||||||
|
|
||||||
|
class Foo(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.make_pad_mask = MakePadMask()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
(N,)
|
||||||
|
"""
|
||||||
|
src_key_padding_mask = self.make_pad_mask(x)
|
||||||
|
return src_key_padding_mask
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pt():
|
||||||
|
f = Foo()
|
||||||
|
f.eval()
|
||||||
|
x = torch.tensor([1, 3, 5])
|
||||||
|
y = f(x)
|
||||||
|
print("y.shape", y.shape)
|
||||||
|
print(y)
|
||||||
|
m = torch.jit.trace(f, x)
|
||||||
|
m.save("foo/make_pad_mask.pt")
|
||||||
|
print(m.graph)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
generate_pt()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20220809)
|
||||||
|
main()
|
37
egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_make_pad_mask.py
Executable file
37
egs/librispeech/ASR/pruned_transducer_stateless3/test_ncnn_make_pad_mask.py
Executable file
@ -0,0 +1,37 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import ncnn
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
x = torch.tensor([1, 3, 5, 8])
|
||||||
|
m = torch.jit.load("foo/make_pad_mask.pt")
|
||||||
|
t = m(x)
|
||||||
|
print(t.shape)
|
||||||
|
print(t)
|
||||||
|
|
||||||
|
param = "foo/make_pad_mask.ncnn.param"
|
||||||
|
model = "foo/make_pad_mask.ncnn.bin"
|
||||||
|
with ncnn.Net() as net:
|
||||||
|
net.load_param(param)
|
||||||
|
net.load_model(model)
|
||||||
|
with net.create_extractor() as ex:
|
||||||
|
x = x.to(torch.int32)
|
||||||
|
ex.input("in0", ncnn.Mat(x.numpy()).clone())
|
||||||
|
|
||||||
|
ret, ncnn_out0 = ex.extract("out0")
|
||||||
|
assert ret == 0, ret
|
||||||
|
n = ncnn_out0.numpy("i")
|
||||||
|
print(n.shape)
|
||||||
|
n = torch.from_numpy(n).to(torch.bool)
|
||||||
|
print(n)
|
||||||
|
assert torch.equal(t, n), (t, n)
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(202208010)
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user