From 57604aac34c8ff3f2398ce7ee916ddd3fe32125f Mon Sep 17 00:00:00 2001 From: KajiMaCN <827272056@qq.com> Date: Fri, 10 Feb 2023 21:28:19 +0800 Subject: [PATCH 1/3] fix tal_csasr data pre-processing (#898) --- egs/tal_csasr/ASR/prepare.sh | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh index d9938fa63..c5d498d74 100755 --- a/egs/tal_csasr/ASR/prepare.sh +++ b/egs/tal_csasr/ASR/prepare.sh @@ -12,9 +12,12 @@ stop_stage=100 # directories and files. If not, they will be downloaded # by this script automatically. # -# - $dl_dir/tal_csasr +# - $dl_dir/TALCS_corpus # You can find three directories:train_set, dev_set, and test_set. # You can get it from https://ai.100tal.com/dataset +# - dev_set +# - test_set +# - train_set # # - $dl_dir/musan # This directory contains the following directories downloaded from @@ -44,7 +47,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Download data" # Before you run this script, you must get the TAL_CSASR dataset # from https://ai.100tal.com/dataset - mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr + if [ ! -d $dl_dir/tal_csasr/TALCS_corpus ]; then + mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr + fi # If you have pre-downloaded it to /path/to/TALCS_corpus, # you can create a symlink @@ -116,7 +121,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi # Prepare text. - # Note: in Linux, you can install jq with the following command: + # Note: in Linux, you can install jq with the following command: # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 # 2. chmod +x ./jq # 3. cp jq /usr/bin From 48c2c22dbe53372e5b6565266d76283a03f6670c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 13 Feb 2023 11:44:25 +0800 Subject: [PATCH 2/3] Fix export to ncnn for lstm3 (#900) --- .../export-for-ncnn.py | 1 + .../ASR/lstm_transducer_stateless3/lstm.py | 12 +++++++- .../ASR/lstm_transducer_stateless3/train.py | 30 +++++++++++++++---- 3 files changed, 36 insertions(+), 7 deletions(-) create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py new file mode 120000 index 000000000..d56cff73f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export-for-ncnn.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/export-for-ncnn.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 6e51b85e4..cb67fffe4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -135,6 +135,7 @@ class RNN(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, aux_layer_period: int = 0, + is_pnnx: bool = False, ) -> None: super(RNN, self).__init__() @@ -176,6 +177,8 @@ class RNN(EncoderInterface): else None, ) + self.is_pnnx = is_pnnx + def forward( self, x: torch.Tensor, @@ -216,7 +219,14 @@ class RNN(EncoderInterface): # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning # # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 3) >> 1) - 1) >> 1 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + + if not torch.jit.is_tracing(): assert x.size(0) == lengths.max().item() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index f56b4fd83..6ef4c9860 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -102,7 +102,28 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dim", type=int, default=512, - help="Encoder output dimesion.", + help="Encoder output dimension.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Decoder output dimension.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="Joiner output dimension.", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Dimension of feed forward.", ) parser.add_argument( @@ -395,14 +416,10 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "dim_feedforward": 2048, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), + "is_pnnx": False, } ) @@ -419,6 +436,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, aux_layer_period=params.aux_layer_period, + is_pnnx=params.is_pnnx, ) return encoder From c102e7fbf07f25cee9baad2b739a827a356c3132 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 13 Feb 2023 12:16:43 +0800 Subject: [PATCH 3/3] more fixes for lstm3 to support exporting to ncnn (#902) --- .../ASR/lstm_transducer_stateless3/lstm.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index cb67fffe4..59a835d35 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -121,6 +121,8 @@ class RNN(EncoderInterface): Period of auxiliary layers used for random combiner during training. If set to 0, will not use the random combiner (Default). You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. """ def __init__( @@ -149,7 +151,13 @@ class RNN(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx self.num_encoder_layers = num_encoder_layers self.d_model = d_model @@ -177,8 +185,6 @@ class RNN(EncoderInterface): else None, ) - self.is_pnnx = is_pnnx - def forward( self, x: torch.Tensor, @@ -226,7 +232,6 @@ class RNN(EncoderInterface): lengths = torch.floor((lengths1 - 1) / 2) lengths = lengths.to(x_lens) - if not torch.jit.is_tracing(): assert x.size(0) == lengths.max().item() @@ -387,7 +392,7 @@ class RNNEncoderLayer(nn.Module): # for cell state assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) src_lstm, new_states = self.lstm(src, states) - src = src + self.dropout(src_lstm) + src = self.dropout(src_lstm) + src # feed forward module src = src + self.dropout(self.feed_forward(src)) @@ -533,6 +538,7 @@ class Conv2dSubsampling(nn.Module): layer1_channels: int = 8, layer2_channels: int = 32, layer3_channels: int = 128, + is_pnnx: bool = False, ) -> None: """ Args: @@ -545,6 +551,9 @@ class Conv2dSubsampling(nn.Module): Number of channels in layer1 layer1_channels: Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. """ assert in_channels >= 9 super().__init__() @@ -587,6 +596,10 @@ class Conv2dSubsampling(nn.Module): channel_dim=-1, min_positive=0.45, max_positive=0.55 ) + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -600,9 +613,15 @@ class Conv2dSubsampling(nn.Module): # On entry, x is (N, T, idim) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-3)//2-1))//2, odim) x = self.out_norm(x) x = self.out_balancer(x)