mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
add args in streaming-ncnn-decode.py
This commit is contained in:
parent
c689b018d7
commit
eb9a5267a9
@ -77,6 +77,27 @@ def get_args():
|
||||
type=str,
|
||||
help="Path to joiner.ncnn.bin",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-encoder-layers",
|
||||
type=int,
|
||||
default=12,
|
||||
help="Number of RNN encoder layers..",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-dim",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Encoder output dimesion.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-hidden-size",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Dimension of feed forward.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_filename",
|
||||
@ -270,10 +291,10 @@ def main():
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
num_encoder_layers = 12
|
||||
num_encoder_layers = args.num_encoder_layers
|
||||
batch_size = 1
|
||||
d_model = 512
|
||||
rnn_hidden_size = 1024
|
||||
d_model = args.encoder_dim
|
||||
rnn_hidden_size = args.rnn_hidden_size
|
||||
|
||||
states = (
|
||||
torch.zeros(num_encoder_layers, batch_size, d_model),
|
||||
|
Loading…
x
Reference in New Issue
Block a user