more fixes for lstm3 to support exporting to ncnn (#902)

This commit is contained in:
Fangjun Kuang 2023-02-13 12:16:43 +08:00 committed by GitHub
parent 48c2c22dbe
commit c102e7fbf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -121,6 +121,8 @@ class RNN(EncoderInterface):
Period of auxiliary layers used for random combiner during training. Period of auxiliary layers used for random combiner during training.
If set to 0, will not use the random combiner (Default). If set to 0, will not use the random combiner (Default).
You can set a positive integer to use the random combiner, e.g., 3. You can set a positive integer to use the random combiner, e.g., 3.
is_pnnx:
True to make this class exportable via PNNX.
""" """
def __init__( def __init__(
@ -149,7 +151,13 @@ class RNN(EncoderInterface):
# That is, it does two things simultaneously: # That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor # (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model # (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_embed = Conv2dSubsampling(
num_features,
d_model,
is_pnnx=is_pnnx,
)
self.is_pnnx = is_pnnx
self.num_encoder_layers = num_encoder_layers self.num_encoder_layers = num_encoder_layers
self.d_model = d_model self.d_model = d_model
@ -177,8 +185,6 @@ class RNN(EncoderInterface):
else None, else None,
) )
self.is_pnnx = is_pnnx
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
@ -226,7 +232,6 @@ class RNN(EncoderInterface):
lengths = torch.floor((lengths1 - 1) / 2) lengths = torch.floor((lengths1 - 1) / 2)
lengths = lengths.to(x_lens) lengths = lengths.to(x_lens)
if not torch.jit.is_tracing(): if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
@ -387,7 +392,7 @@ class RNNEncoderLayer(nn.Module):
# for cell state # for cell state
assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
src_lstm, new_states = self.lstm(src, states) src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm) src = self.dropout(src_lstm) + src
# feed forward module # feed forward module
src = src + self.dropout(self.feed_forward(src)) src = src + self.dropout(self.feed_forward(src))
@ -533,6 +538,7 @@ class Conv2dSubsampling(nn.Module):
layer1_channels: int = 8, layer1_channels: int = 8,
layer2_channels: int = 32, layer2_channels: int = 32,
layer3_channels: int = 128, layer3_channels: int = 128,
is_pnnx: bool = False,
) -> None: ) -> None:
""" """
Args: Args:
@ -545,6 +551,9 @@ class Conv2dSubsampling(nn.Module):
Number of channels in layer1 Number of channels in layer1
layer1_channels: layer1_channels:
Number of channels in layer2 Number of channels in layer2
is_pnnx:
True if we are converting the model to PNNX format.
False otherwise.
""" """
assert in_channels >= 9 assert in_channels >= 9
super().__init__() super().__init__()
@ -587,6 +596,10 @@ class Conv2dSubsampling(nn.Module):
channel_dim=-1, min_positive=0.45, max_positive=0.55 channel_dim=-1, min_positive=0.45, max_positive=0.55
) )
# ncnn supports only batch size == 1
self.is_pnnx = is_pnnx
self.conv_out_dim = self.out.weight.shape[1]
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x. """Subsample x.
@ -600,9 +613,15 @@ class Conv2dSubsampling(nn.Module):
# On entry, x is (N, T, idim) # On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x) x = self.conv(x)
if torch.jit.is_tracing() and self.is_pnnx:
x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
x = self.out(x)
else:
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size() b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-3)//2-1))//2, odim) # Now x is of shape (N, ((T-3)//2-1))//2, odim)
x = self.out_norm(x) x = self.out_norm(x)
x = self.out_balancer(x) x = self.out_balancer(x)