diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py
new file mode 120000
index 000000000..d56cff73f
--- /dev/null
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py
@@ -0,0 +1 @@
+../lstm_transducer_stateless2/export-for-ncnn.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py
index 6e51b85e4..cb67fffe4 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py
@@ -135,6 +135,7 @@ class RNN(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         aux_layer_period: int = 0,
+        is_pnnx: bool = False,
     ) -> None:
         super(RNN, self).__init__()
 
@@ -176,6 +177,8 @@ class RNN(EncoderInterface):
             else None,
         )
 
+        self.is_pnnx = is_pnnx
+
     def forward(
         self,
         x: torch.Tensor,
@@ -216,7 +219,14 @@ class RNN(EncoderInterface):
         # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
         #
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
-        lengths = (((x_lens - 3) >> 1) - 1) >> 1
+        if not self.is_pnnx:
+            lengths = (((x_lens - 3) >> 1) - 1) >> 1
+        else:
+            lengths1 = torch.floor((x_lens - 3) / 2)
+            lengths = torch.floor((lengths1 - 1) / 2)
+            lengths = lengths.to(x_lens)
+
+
         if not torch.jit.is_tracing():
             assert x.size(0) == lengths.max().item()
 
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
index f56b4fd83..6ef4c9860 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
@@ -102,7 +102,28 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--encoder-dim",
         type=int,
         default=512,
-        help="Encoder output dimesion.",
+        help="Encoder output dimension.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Decoder output dimension.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="Joiner output dimension.",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Dimension of feed forward.",
     )
 
     parser.add_argument(
@@ -395,14 +416,10 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "dim_feedforward": 2048,
-            # parameters for decoder
-            "decoder_dim": 512,
-            # parameters for joiner
-            "joiner_dim": 512,
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
+            "is_pnnx": False,
         }
     )
 
@@ -419,6 +436,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         aux_layer_period=params.aux_layer_period,
+        is_pnnx=params.is_pnnx,
     )
     return encoder