Fix export to ncnn for lstm3 (#900)

Fangjun Kuang 2023-02-13 11:44:25 +08:00 committed by GitHub
parent 57604aac34
commit 48c2c22dbe
3 changed files with 36 additions and 7 deletions

lstm_transducer_stateless3/export-for-ncnn.py (new file, a symlink)

@@ -0,0 +1 @@
+../lstm_transducer_stateless2/export-for-ncnn.py

lstm_transducer_stateless3/lstm.py

@@ -135,6 +135,7 @@ class RNN(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         aux_layer_period: int = 0,
+        is_pnnx: bool = False,
     ) -> None:
         super(RNN, self).__init__()
@@ -176,6 +177,8 @@ class RNN(EncoderInterface):
             else None,
         )
 
+        self.is_pnnx = is_pnnx
+
     def forward(
         self,
         x: torch.Tensor,
@ -216,7 +219,14 @@ class RNN(EncoderInterface):
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
# #
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 3) >> 1) - 1) >> 1 if not self.is_pnnx:
lengths = (((x_lens - 3) >> 1) - 1) >> 1
else:
lengths1 = torch.floor((x_lens - 3) / 2)
lengths = torch.floor((lengths1 - 1) / 2)
lengths = lengths.to(x_lens)
if not torch.jit.is_tracing(): if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()

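The new branch is needed, presumably, because the PNNX tracer used for
ncnn export does not support Python's bit-shift operators on tensors,
so the subsampled-length arithmetic is re-expressed with torch.floor.
A quick sanity check, not part of the commit (the sample lengths are
made up), shows that the two forms compute identical results:

import torch

# Bit-shift form used during training/inference.
x_lens = torch.tensor([7, 16, 57, 100, 2001])
shifted = (((x_lens - 3) >> 1) - 1) >> 1

# PNNX-friendly form: the same integer division via torch.floor.
lengths1 = torch.floor((x_lens - 3) / 2)
floored = torch.floor((lengths1 - 1) / 2).to(x_lens)

assert torch.equal(shifted, floored)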
lstm_transducer_stateless3/train.py

@@ -102,7 +102,28 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--encoder-dim",
         type=int,
         default=512,
-        help="Encoder output dimesion.",
+        help="Encoder output dimension.",
     )
 
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Decoder output dimension.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="Joiner output dimension.",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Dimension of feed forward.",
+    )
+
     parser.add_argument(
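The three new options replace values that were hard-coded in
get_params() (see the hunk below), with identical defaults. A minimal
stand-alone sketch of just the added options (the parser here is a
stand-in for the real one in train.py):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--decoder-dim", type=int, default=512)
parser.add_argument("--joiner-dim", type=int, default=512)
parser.add_argument("--dim-feedforward", type=int, default=2048)

# With no command-line overrides, the defaults reproduce the old
# hard-coded configuration.
args = parser.parse_args([])
assert (args.decoder_dim, args.joiner_dim, args.dim_feedforward) == (512, 512, 2048)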
@ -395,14 +416,10 @@ def get_params() -> AttributeDict:
# parameters for conformer # parameters for conformer
"feature_dim": 80, "feature_dim": 80,
"subsampling_factor": 4, "subsampling_factor": 4,
"dim_feedforward": 2048,
# parameters for decoder
"decoder_dim": 512,
# parameters for joiner
"joiner_dim": 512,
# parameters for Noam # parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate "model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(), "env_info": get_env_info(),
"is_pnnx": False,
} }
) )
@@ -419,6 +436,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         aux_layer_period=params.aux_layer_period,
+        is_pnnx=params.is_pnnx,
     )
     return encoder
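Training keeps the "is_pnnx": False default from get_params(), so the
bit-shift path is untouched there; an export script is expected to flip
the flag before building the encoder. A hedged sketch of that flow
(AttributeDict is a minimal stand-in for icefall's helper, and the
override mirrors what export-for-ncnn.py would do):

class AttributeDict(dict):
    """Minimal stand-in: a dict whose keys double as attributes."""

    def __getattr__(self, key):
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value

params = AttributeDict({"is_pnnx": False})  # default from get_params()

# An ncnn export script would set this before calling get_encoder_model(),
# which forwards it to the encoder as is_pnnx=params.is_pnnx.
params.is_pnnx = True
assert params.is_pnnx is True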