Mirror of https://github.com/k2-fsa/icefall.git

commit b3fa59d68a
Merge branch 'master' of github.com:marcoyang1998/icefall into add_lstm_transducer
@@ -0,0 +1 @@
+../lstm_transducer_stateless2/export-for-ncnn.py
@@ -121,6 +121,8 @@ class RNN(EncoderInterface):
         Period of auxiliary layers used for random combiner during training.
         If set to 0, will not use the random combiner (Default).
         You can set a positive integer to use the random combiner, e.g., 3.
+      is_pnnx:
+        True to make this class exportable via PNNX.
     """

     def __init__(
@@ -135,6 +137,7 @@ class RNN(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         aux_layer_period: int = 0,
+        is_pnnx: bool = False,
     ) -> None:
         super(RNN, self).__init__()

@@ -148,7 +151,13 @@ class RNN(EncoderInterface):
         # That is, it does two things simultaneously:
         # (1) subsampling: T -> T//subsampling_factor
         # (2) embedding: num_features -> d_model
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+        self.encoder_embed = Conv2dSubsampling(
+            num_features,
+            d_model,
+            is_pnnx=is_pnnx,
+        )
+
+        self.is_pnnx = is_pnnx

         self.num_encoder_layers = num_encoder_layers
         self.d_model = d_model
@@ -216,7 +225,13 @@ class RNN(EncoderInterface):
         # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
         #
         # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
-        lengths = (((x_lens - 3) >> 1) - 1) >> 1
+        if not self.is_pnnx:
+            lengths = (((x_lens - 3) >> 1) - 1) >> 1
+        else:
+            lengths1 = torch.floor((x_lens - 3) / 2)
+            lengths = torch.floor((lengths1 - 1) / 2)
+            lengths = lengths.to(x_lens)
+
         if not torch.jit.is_tracing():
             assert x.size(0) == lengths.max().item()

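
Both branches of the new length computation evaluate ((x_lens - 3) // 2 - 1) // 2; the PNNX branch merely spells the floor divisions with torch.floor instead of bit shifts, presumably because the shift ops do not trace/export cleanly through PNNX. A standalone sanity check (a sketch, not part of the patch; assumes only torch):

import torch

x_lens = torch.arange(9, 100)  # plausible input frame counts

# default branch: integer bit shifts
lengths_shift = (((x_lens - 3) >> 1) - 1) >> 1

# PNNX branch: the same floor divisions via torch.floor, cast back to int
lengths1 = torch.floor((x_lens - 3) / 2)
lengths_floor = torch.floor((lengths1 - 1) / 2).to(x_lens)

assert torch.equal(lengths_shift, lengths_floor)  # identical subsampled lengths
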
@@ -377,7 +392,7 @@ class RNNEncoderLayer(nn.Module):
         # for cell state
         assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
         src_lstm, new_states = self.lstm(src, states)
-        src = src + self.dropout(src_lstm)
+        src = self.dropout(src_lstm) + src

         # feed forward module
         src = src + self.dropout(self.feed_forward(src))
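
Numerically the swapped residual is the same tensor; only the operand order of the addition in the traced graph changes, which is presumably what matters to the ncnn/PNNX exporter. A minimal check (a sketch, not part of the patch):

import torch

torch.manual_seed(0)
src = torch.randn(5, 1, 512)       # (T, N, C); shapes chosen only for the demo
src_lstm = torch.randn(5, 1, 512)
dropout = torch.nn.Dropout(p=0.1)

d = dropout(src_lstm)
assert torch.equal(src + d, d + src)  # same values, different operand order
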
@@ -523,6 +538,7 @@ class Conv2dSubsampling(nn.Module):
         layer1_channels: int = 8,
         layer2_channels: int = 32,
         layer3_channels: int = 128,
+        is_pnnx: bool = False,
     ) -> None:
         """
         Args:
@@ -535,6 +551,9 @@ class Conv2dSubsampling(nn.Module):
             Number of channels in layer1
           layer1_channels:
             Number of channels in layer2
+          is_pnnx:
+            True if we are converting the model to PNNX format.
+            False otherwise.
         """
         assert in_channels >= 9
         super().__init__()
@@ -577,6 +596,10 @@ class Conv2dSubsampling(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55
         )

+        # ncnn supports only batch size == 1
+        self.is_pnnx = is_pnnx
+        self.conv_out_dim = self.out.weight.shape[1]
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Subsample x.

@@ -590,9 +613,15 @@ class Conv2dSubsampling(nn.Module):
         # On entry, x is (N, T, idim)
         x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
         x = self.conv(x)
-        # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
-        b, c, t, f = x.size()
-        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+
+        if torch.jit.is_tracing() and self.is_pnnx:
+            x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim)
+            x = self.out(x)
+        else:
+            # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
+            b, c, t, f = x.size()
+            x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+
         # Now x is of shape (N, ((T-3)//2-1))//2, odim)
         x = self.out_norm(x)
         x = self.out_balancer(x)
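
For a single utterance the traced path is just a different spelling of the same reshuffle: self.conv_out_dim is the in_features of the self.out linear layer, i.e. c * f after the convolution, and with batch size 1 (the only case ncnn supports, per the comment above) permute+reshape reproduces transpose+view exactly. A standalone check of the tensor reshuffling alone (a sketch, not part of the patch):

import torch

b, c, t, f = 1, 128, 17, 19        # batch size must be 1 for the traced/PNNX path
x = torch.randn(b, c, t, f)        # stands in for the output of self.conv: (N, C, T', F')
conv_out_dim = c * f               # what self.out.weight.shape[1] equals

# eager path
y_eager = x.transpose(1, 2).contiguous().view(b, t, c * f)

# traced PNNX path: permute instead of transpose, reshape with a hard-coded leading 1
y_pnnx = x.permute(0, 2, 1, 3).reshape(1, -1, conv_out_dim)

assert torch.equal(y_eager, y_pnnx)  # identical (1, T', C*F') tensor
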
@@ -102,7 +102,28 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--encoder-dim",
         type=int,
         default=512,
-        help="Encoder output dimesion.",
+        help="Encoder output dimension.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Decoder output dimension.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="Joiner output dimension.",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Dimension of feed forward.",
     )

     parser.add_argument(
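
With this change the decoder, joiner, and feed-forward dimensions become ordinary command-line options, carrying the same defaults that the next hunk removes from get_params(), so they can be overridden per run. A minimal standalone sketch of the effect (hypothetical, registering two of the options by hand):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--decoder-dim", type=int, default=512, help="Decoder output dimension.")
parser.add_argument("--dim-feedforward", type=int, default=2048, help="Dimension of feed forward.")

args = parser.parse_args([])                        # no flags: the previously hard-coded values
print(args.decoder_dim, args.dim_feedforward)       # 512 2048

args = parser.parse_args(["--decoder-dim", "320"])  # now overridable per run
print(args.decoder_dim, args.dim_feedforward)       # 320 2048
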
@@ -395,14 +416,10 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "dim_feedforward": 2048,
-            # parameters for decoder
-            "decoder_dim": 512,
-            # parameters for joiner
-            "joiner_dim": 512,
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
+            "is_pnnx": False,
         }
     )

@@ -419,6 +436,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         aux_layer_period=params.aux_layer_period,
+        is_pnnx=params.is_pnnx,
     )
     return encoder

@@ -12,9 +12,12 @@ stop_stage=100
 # directories and files. If not, they will be downloaded
 # by this script automatically.
 #
-#  - $dl_dir/tal_csasr
+#  - $dl_dir/TALCS_corpus
 #    You can find three directories:train_set, dev_set, and test_set.
 #    You can get it from https://ai.100tal.com/dataset
+#      - dev_set
+#      - test_set
+#      - train_set
 #
 #  - $dl_dir/musan
 #    This directory contains the following directories downloaded from
@@ -44,7 +47,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   log "Stage 0: Download data"
   # Before you run this script, you must get the TAL_CSASR dataset
   # from https://ai.100tal.com/dataset
-  mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr
+  if [ ! -d $dl_dir/tal_csasr/TALCS_corpus ]; then
+    mv $dl_dir/TALCS_corpus $dl_dir/tal_csasr
+  fi

   # If you have pre-downloaded it to /path/to/TALCS_corpus,
   # you can create a symlink
@@ -116,7 +121,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   fi

   # Prepare text.
   # Note: in Linux, you can install jq with the following command:
   # 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64
   # 2. chmod +x ./jq
   # 3. cp jq /usr/bin