mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix export to ncnn for lstm3 (#900)
This commit is contained in:
parent
57604aac34
commit
48c2c22dbe
@ -0,0 +1 @@
|
|||||||
|
../lstm_transducer_stateless2/export-for-ncnn.py
|
@ -135,6 +135,7 @@ class RNN(EncoderInterface):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
layer_dropout: float = 0.075,
|
layer_dropout: float = 0.075,
|
||||||
aux_layer_period: int = 0,
|
aux_layer_period: int = 0,
|
||||||
|
is_pnnx: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(RNN, self).__init__()
|
super(RNN, self).__init__()
|
||||||
|
|
||||||
@ -176,6 +177,8 @@ class RNN(EncoderInterface):
|
|||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.is_pnnx = is_pnnx
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -216,7 +219,14 @@ class RNN(EncoderInterface):
|
|||||||
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
|
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
|
||||||
#
|
#
|
||||||
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
||||||
lengths = (((x_lens - 3) >> 1) - 1) >> 1
|
if not self.is_pnnx:
|
||||||
|
lengths = (((x_lens - 3) >> 1) - 1) >> 1
|
||||||
|
else:
|
||||||
|
lengths1 = torch.floor((x_lens - 3) / 2)
|
||||||
|
lengths = torch.floor((lengths1 - 1) / 2)
|
||||||
|
lengths = lengths.to(x_lens)
|
||||||
|
|
||||||
|
|
||||||
if not torch.jit.is_tracing():
|
if not torch.jit.is_tracing():
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
|
|
||||||
|
@ -102,7 +102,28 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
"--encoder-dim",
|
"--encoder-dim",
|
||||||
type=int,
|
type=int,
|
||||||
default=512,
|
default=512,
|
||||||
help="Encoder output dimesion.",
|
help="Encoder output dimension.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-dim",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Decoder output dimension.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-dim",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Joiner output dimension.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dim-feedforward",
|
||||||
|
type=int,
|
||||||
|
default=2048,
|
||||||
|
help="Dimension of feed forward.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -395,14 +416,10 @@ def get_params() -> AttributeDict:
|
|||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"dim_feedforward": 2048,
|
|
||||||
# parameters for decoder
|
|
||||||
"decoder_dim": 512,
|
|
||||||
# parameters for joiner
|
|
||||||
"joiner_dim": 512,
|
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"model_warm_step": 3000, # arg given to model, not for lrate
|
"model_warm_step": 3000, # arg given to model, not for lrate
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
|
"is_pnnx": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -419,6 +436,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
aux_layer_period=params.aux_layer_period,
|
aux_layer_period=params.aux_layer_period,
|
||||||
|
is_pnnx=params.is_pnnx,
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user