This commit is contained in:
yfyeung 2024-07-04 14:53:49 +08:00
parent 18da3e8975
commit 53da22ecc7
2 changed files with 117 additions and 3090 deletions

File diff suppressed because it is too large Load Diff

View File

@ -63,7 +63,6 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
@ -148,26 +147,17 @@ def get_parser():
""",
)
# NOTE: decoder params
parser.add_argument(
"--lstm-type",
type=str,
default="lstm",
choices=["lstm", "slstm", "mlstm", "xlstm"],
help="Implementation of LSTM in the decoder.",
)
parser.add_argument(
"--num-decoder-layers",
type=int,
default=4,
default=1,
help="Number of decoder layer of the LSTM decoder.",
)
parser.add_argument(
"--decoder-embedding-dim",
type=int,
default=1024,
default=512,
help="The embedding dimension of the LSTM decoder.",
)
@ -426,7 +416,9 @@ def decode_one_batch(
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":