add args in streaming-ncnn-decode.py

This commit is contained in:
marcoyang 2023-02-13 12:45:47 +08:00
parent c689b018d7
commit eb9a5267a9

View File

@ -77,6 +77,27 @@ def get_args():
type=str,
help="Path to joiner.ncnn.bin",
)
parser.add_argument(
"--num-encoder-layers",
type=int,
default=12,
help="Number of RNN encoder layers..",
)
parser.add_argument(
"--encoder-dim",
type=int,
default=512,
help="Encoder output dimesion.",
)
parser.add_argument(
"--rnn-hidden-size",
type=int,
default=2048,
help="Dimension of feed forward.",
)
parser.add_argument(
"sound_filename",
@ -270,10 +291,10 @@ def main():
)[0]
logging.info(wave_samples.shape)
num_encoder_layers = 12
num_encoder_layers = args.num_encoder_layers
batch_size = 1
d_model = 512
rnn_hidden_size = 1024
d_model = args.encoder_dim
rnn_hidden_size = args.rnn_hidden_size
states = (
torch.zeros(num_encoder_layers, batch_size, d_model),