mirror of https://github.com/k2-fsa/icefall.git
more fixes for lstm3 to support exporting to ncnn (#902)
This commit is contained in:
parent
48c2c22dbe
commit
c102e7fbf0
@@ -121,6 +121,8 @@ class RNN(EncoderInterface):
         Period of auxiliary layers used for random combiner during training.
         If set to 0, will not use the random combiner (Default).
         You can set a positive integer to use the random combiner, e.g., 3.
+      is_pnnx:
+        True to make this class exportable via PNNX.
     """

     def __init__(
@@ -149,7 +151,13 @@ class RNN(EncoderInterface):
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_features -> d_model
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+        self.encoder_embed = Conv2dSubsampling(
+            num_features,
+            d_model,
+            is_pnnx=is_pnnx,
+        )
+
+        self.is_pnnx = is_pnnx

         self.num_encoder_layers = num_encoder_layers
         self.d_model = d_model
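For context, a minimal sketch of how the flag would be used at export time: build the encoder with is_pnnx=True, trace it with torch.jit, and hand the saved TorchScript file to pnnx for conversion to ncnn. The constructor arguments beyond num_features, d_model, and is_pnnx, and all concrete shapes, are illustrative assumptions, not taken from this diff.

    import torch

    encoder = RNN(
        num_features=80,  # illustrative value
        d_model=512,      # illustrative value
        is_pnnx=True,     # enable the PNNX/ncnn-friendly code paths
    )
    encoder.eval()

    # ncnn supports only batch size == 1 (see the Conv2dSubsampling hunks below)
    x = torch.rand(1, 100, 80)
    x_lens = torch.tensor([100])
    traced = torch.jit.trace(encoder, (x, x_lens))
    traced.save("encoder.pt")  # then convert with, e.g.: pnnx encoder.pt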
@@ -177,8 +185,6 @@ class RNN(EncoderInterface):
             else None,
         )

-        self.is_pnnx = is_pnnx
-
     def forward(
         self,
         x: torch.Tensor,
@@ -226,7 +232,6 @@ class RNN(EncoderInterface):
             lengths = torch.floor((lengths1 - 1) / 2)
             lengths = lengths.to(x_lens)
-

         if not torch.jit.is_tracing():
             assert x.size(0) == lengths.max().item()

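The two floor operations reproduce Conv2dSubsampling's length reduction ((T - 3) // 2 - 1) // 2 using ops that trace cleanly. A quick arithmetic check, assuming lengths1 = floor((x_lens - 3) / 2), which matches the subsampling formula:

    import torch

    x_lens = torch.tensor([100.0])
    lengths1 = torch.floor((x_lens - 3) / 2)   # -> 48.0
    lengths = torch.floor((lengths1 - 1) / 2)  # -> 23.0
    assert int(lengths.item()) == ((100 - 3) // 2 - 1) // 2  # 23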
@@ -387,7 +392,7 @@ class RNNEncoderLayer(nn.Module):
         # for cell state
         assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
         src_lstm, new_states = self.lstm(src, states)
-        src = src + self.dropout(src_lstm)
+        src = self.dropout(src_lstm) + src

         # feed forward module
         src = src + self.dropout(self.feed_forward(src))
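Swapping the operands of this residual add does not change the result: addition commutes, and dropout is the identity in eval/tracing mode. The swap presumably only changes which tensor appears as the first input of the add node in the traced graph, putting it in a form the PNNX-to-ncnn conversion handles. A sanity check of the numeric equivalence:

    import torch

    torch.manual_seed(0)
    src = torch.rand(6, 1, 8)
    src_lstm = torch.rand(6, 1, 8)
    dropout = torch.nn.Dropout(p=0.1)
    dropout.eval()  # identity in eval mode, so both forms are deterministic

    assert torch.equal(src + dropout(src_lstm), dropout(src_lstm) + src)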
@@ -533,6 +538,7 @@ class Conv2dSubsampling(nn.Module):
         layer1_channels: int = 8,
         layer2_channels: int = 32,
         layer3_channels: int = 128,
+        is_pnnx: bool = False,
     ) -> None:
         """
         Args:
@@ -545,6 +551,9 @@ class Conv2dSubsampling(nn.Module):
             Number of channels in layer1
           layer1_channels:
             Number of channels in layer2
+          is_pnnx:
+            True if we are converting the model to PNNX format.
+            False otherwise.
         """
         assert in_channels >= 9
         super().__init__()
@@ -587,6 +596,10 @@ class Conv2dSubsampling(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55
         )

+        # ncnn supports only batch size == 1
+        self.is_pnnx = is_pnnx
+        self.conv_out_dim = self.out.weight.shape[1]
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.

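nn.Linear stores its weight as (out_features, in_features), so self.out.weight.shape[1] recovers the flattened conv output size c * f as a constant fixed at construction time. The PNNX branch in the next hunk can then reshape without calling x.size(), whose results get baked in (or produce unsupported ops) during tracing. A small illustration with assumed sizes:

    import torch.nn as nn

    out = nn.Linear(in_features=2432, out_features=512)  # sizes are illustrative
    assert out.weight.shape == (512, 2432)  # (out_features, in_features)
    conv_out_dim = out.weight.shape[1]  # == c * f, known without tracing x.size()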
@@ -600,9 +613,15 @@ class Conv2dSubsampling(nn.Module):
         # On entry, x is (N, T, idim)
         x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
         x = self.conv(x)
-        # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
-        b, c, t, f = x.size()
-        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+
+        if torch.jit.is_tracing() and self.is_pnnx:
+            x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
+            x = self.out(x)
+        else:
+            # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
+            b, c, t, f = x.size()
+            x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+
         # Now x is of shape (N, ((T-3)//2-1))//2, odim)
         x = self.out_norm(x)
         x = self.out_balancer(x)