Merge branch 'master' of github.com:marcoyang1998/icefall into add_lstm_transducer

This commit is contained in:
marcoyang 2023-02-13 12:47:34 +08:00
commit b3fa59d68a
4 changed files with 68 additions and 15 deletions

View File

@ -0,0 +1 @@
../lstm_transducer_stateless2/export-for-ncnn.py

View File

@ -121,6 +121,8 @@ class RNN(EncoderInterface):
Period of auxiliary layers used for random combiner during training. Period of auxiliary layers used for random combiner during training.
If set to 0, will not use the random combiner (Default). If set to 0, will not use the random combiner (Default).
You can set a positive integer to use the random combiner, e.g., 3. You can set a positive integer to use the random combiner, e.g., 3.
is_pnnx:
True to make this class exportable via PNNX.
""" """
def __init__( def __init__(
@ -135,6 +137,7 @@ class RNN(EncoderInterface):
dropout: float = 0.1, dropout: float = 0.1,
layer_dropout: float = 0.075, layer_dropout: float = 0.075,
aux_layer_period: int = 0, aux_layer_period: int = 0,
is_pnnx: bool = False,
) -> None: ) -> None:
super(RNN, self).__init__() super(RNN, self).__init__()
@ -148,7 +151,13 @@ class RNN(EncoderInterface):
# That is, it does two things simultaneously: # That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor # (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model # (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model) self.encoder_embed = Conv2dSubsampling(
num_features,
d_model,
is_pnnx=is_pnnx,
)
self.is_pnnx = is_pnnx
self.num_encoder_layers = num_encoder_layers self.num_encoder_layers = num_encoder_layers
self.d_model = d_model self.d_model = d_model
@ -216,7 +225,13 @@ class RNN(EncoderInterface):
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issue a warning # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue a warning
# #
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 3) >> 1) - 1) >> 1 if not self.is_pnnx:
lengths = (((x_lens - 3) >> 1) - 1) >> 1
else:
lengths1 = torch.floor((x_lens - 3) / 2)
lengths = torch.floor((lengths1 - 1) / 2)
lengths = lengths.to(x_lens)
if not torch.jit.is_tracing(): if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item() assert x.size(0) == lengths.max().item()
@ -377,7 +392,7 @@ class RNNEncoderLayer(nn.Module):
# for cell state # for cell state
assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
src_lstm, new_states = self.lstm(src, states) src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm) src = self.dropout(src_lstm) + src
# feed forward module # feed forward module
src = src + self.dropout(self.feed_forward(src)) src = src + self.dropout(self.feed_forward(src))
@ -523,6 +538,7 @@ class Conv2dSubsampling(nn.Module):
layer1_channels: int = 8, layer1_channels: int = 8,
layer2_channels: int = 32, layer2_channels: int = 32,
layer3_channels: int = 128, layer3_channels: int = 128,
is_pnnx: bool = False,
) -> None: ) -> None:
""" """
Args: Args:
@ -535,6 +551,9 @@ class Conv2dSubsampling(nn.Module):
Number of channels in layer1 Number of channels in layer1
layer2_channels: layer2_channels:
Number of channels in layer2 Number of channels in layer2
is_pnnx:
True if we are converting the model to PNNX format.
False otherwise.
""" """
assert in_channels >= 9 assert in_channels >= 9
super().__init__() super().__init__()
@ -577,6 +596,10 @@ class Conv2dSubsampling(nn.Module):
channel_dim=-1, min_positive=0.45, max_positive=0.55 channel_dim=-1, min_positive=0.45, max_positive=0.55
) )
# ncnn supports only batch size == 1
self.is_pnnx = is_pnnx
self.conv_out_dim = self.out.weight.shape[1]
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x. """Subsample x.
@ -590,9 +613,15 @@ class Conv2dSubsampling(nn.Module):
# On entry, x is (N, T, idim) # On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x) x = self.conv(x)
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size() if torch.jit.is_tracing() and self.is_pnnx:
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
x = self.out(x)
else:
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-3)//2-1)//2, odim) # Now x is of shape (N, ((T-3)//2-1)//2, odim)
x = self.out_norm(x) x = self.out_norm(x)
x = self.out_balancer(x) x = self.out_balancer(x)

View File

@ -102,7 +102,28 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--encoder-dim", "--encoder-dim",
type=int, type=int,
default=512, default=512,
help="Encoder output dimesion.", help="Encoder output dimension.",
)
parser.add_argument(
"--decoder-dim",
type=int,
default=512,
help="Decoder output dimension.",
)
parser.add_argument(
"--joiner-dim",
type=int,
default=512,
help="Joiner output dimension.",
)
parser.add_argument(
"--dim-feedforward",
type=int,
default=2048,
help="Dimension of feed forward.",
) )
parser.add_argument( parser.add_argument(
@ -395,14 +416,10 @@ def get_params() -> AttributeDict:
# parameters for conformer # parameters for conformer
"feature_dim": 80, "feature_dim": 80,
"subsampling_factor": 4, "subsampling_factor": 4,
"dim_feedforward": 2048,
# parameters for decoder
"decoder_dim": 512,
# parameters for joiner
"joiner_dim": 512,
# parameters for Noam # parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate "model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(), "env_info": get_env_info(),
"is_pnnx": False,
} }
) )
@ -419,6 +436,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
aux_layer_period=params.aux_layer_period, aux_layer_period=params.aux_layer_period,
is_pnnx=params.is_pnnx,
) )
return encoder return encoder

View File

@ -12,9 +12,12 @@ stop_stage=100
# directories and files. If not, they will be downloaded # directories and files. If not, they will be downloaded
# by this script automatically. # by this script automatically.
# #
# - $dl_dir/tal_csasr # - $dl_dir/TALCS_corpus
# You can find three directories: train_set, dev_set, and test_set. # You can find three directories: train_set, dev_set, and test_set.
# You can get it from https://ai.100tal.com/dataset # You can get it from https://ai.100tal.com/dataset
# - dev_set
# - test_set
# - train_set
# #
# - $dl_dir/musan # - $dl_dir/musan
# This directory contains the following directories downloaded from # This directory contains the following directories downloaded from
@ -44,7 +47,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data" log "Stage 0: Download data"
# Before you run this script, you must get the TAL_CSASR dataset # Before you run this script, you must get the TAL_CSASR dataset
# from https://ai.100tal.com/dataset # from https://ai.100tal.com/dataset
mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr if [ ! -d $dl_dir/tal_csasr/TALCS_corpus ]; then
mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr
fi
# If you have pre-downloaded it to /path/to/TALCS_corpus, # If you have pre-downloaded it to /path/to/TALCS_corpus,
# you can create a symlink # you can create a symlink
@ -116,7 +121,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
fi fi
# Prepare text. # Prepare text.
# Note: in Linux, you can install jq with the following command: # Note: in Linux, you can install jq with the following command:
# 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
# 2. chmod +x ./jq # 2. chmod +x ./jq
# 3. cp jq /usr/bin # 3. cp jq /usr/bin