mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-11 02:52:18 +00:00
update
This commit is contained in:
parent
18da3e8975
commit
53da22ecc7
File diff suppressed because it is too large
Load Diff
@ -63,7 +63,6 @@ from icefall.checkpoint import (
|
|||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
make_pad_mask,
|
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
str2bool,
|
str2bool,
|
||||||
@ -148,26 +147,17 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: decoder params
|
|
||||||
parser.add_argument(
|
|
||||||
"--lstm-type",
|
|
||||||
type=str,
|
|
||||||
default="lstm",
|
|
||||||
choices=["lstm", "slstm", "mlstm", "xlstm"],
|
|
||||||
help="Implementation of LSTM in the decoder.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-decoder-layers",
|
"--num-decoder-layers",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=1,
|
||||||
help="Number of decoder layer of the LSTM decoder.",
|
help="Number of decoder layer of the LSTM decoder.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decoder-embedding-dim",
|
"--decoder-embedding-dim",
|
||||||
type=int,
|
type=int,
|
||||||
default=1024,
|
default=512,
|
||||||
help="The embedding dimension of the LSTM decoder.",
|
help="The embedding dimension of the LSTM decoder.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -426,7 +416,9 @@ def decode_one_batch(
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user