mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
minor fixes
This commit is contained in:
parent
9a01b9098d
commit
fb45b95c90
@ -198,7 +198,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="lstm_transducer_stateless2/exp",
|
default="pruned_transducer_stateless5/exp",
|
||||||
help="The experiment dir",
|
help="The experiment dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -228,7 +228,7 @@ def get_parser():
|
|||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
- modified-beam-search3 # for rnn lm shallow fusion
|
- modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
||||||
If you use fast_beam_search_nbest_LG, you have to specify
|
If you use fast_beam_search_nbest_LG, you have to specify
|
||||||
`--lang-dir`, which should contain `LG.pt`.
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
""",
|
""",
|
||||||
@ -266,6 +266,20 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="The chunk size for decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--left-context",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="left context can be seen during decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-contexts",
|
"--max-contexts",
|
||||||
type=int,
|
type=int,
|
||||||
@ -331,7 +345,7 @@ def get_parser():
|
|||||||
"--rnn-lm-scale",
|
"--rnn-lm-scale",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.0,
|
||||||
help="""Used only when --method is modified_beam_search3.
|
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||||
It specifies the path to RNN LM exp dir.
|
It specifies the path to RNN LM exp dir.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -430,7 +444,7 @@ def decode_one_batch(
|
|||||||
word_table:
|
word_table:
|
||||||
The word symbol table.
|
The word symbol table.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
Returns:
|
Returns:
|
||||||
@ -560,7 +574,7 @@ def decode_one_batch(
|
|||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
hyp = greedy_search(
|
hyp = greedy_search(
|
||||||
@ -606,7 +620,7 @@ def decode_dataset(
|
|||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
rnnlm: Optional[RnnLmModel] = None,
|
rnnlm: Optional[RnnLmModel] = None,
|
||||||
rnnlm_scale: float = 1.0,
|
rnnlm_scale: float = 1.0,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -683,7 +697,7 @@ def decode_dataset(
|
|||||||
def save_results(
|
def save_results(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
test_set_name: str,
|
test_set_name: str,
|
||||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
):
|
):
|
||||||
test_set_wers = dict()
|
test_set_wers = dict()
|
||||||
for key, results in results_dict.items():
|
for key, results in results_dict.items():
|
||||||
@ -751,7 +765,9 @@ def main():
|
|||||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
if params.simulate_streaming:
|
||||||
|
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||||
|
params.suffix += f"-left-context-{params.left_context}"
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
@ -791,6 +807,11 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
|
assert (
|
||||||
|
params.causal_convolution
|
||||||
|
), "Decoding in streaming requires causal convolution"
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user