minor fixes

marcoyang 2022-11-02 18:11:39 +08:00
parent 9a01b9098d
commit fb45b95c90


@@ -198,7 +198,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="lstm_transducer_stateless2/exp",
+        default="pruned_transducer_stateless5/exp",
         help="The experiment dir",
     )
@@ -228,7 +228,7 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
-          - modified-beam-search3 # for rnn lm shallow fusion
+          - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -266,6 +266,20 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--decode-chunk-size",
+        type=int,
+        default=16,
+        help="The chunk size for decoding (in frames after subsampling)",
+    )
+
+    parser.add_argument(
+        "--left-context",
+        type=int,
+        default=64,
+        help="left context can be seen during decoding (in frames after subsampling)",
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
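
Note: to give a sense of scale for these defaults, if the model subsamples by a factor of 4 and the features use a 10 ms frame shift (both assumptions, not stated in this diff), the chunk and left context correspond to the durations computed below.

# Convert the streaming-decoding defaults from subsampled frames to time.
# Assumptions: 4x subsampling, 10 ms feature frame shift.
subsampling_factor = 4
frame_shift_ms = 10

decode_chunk_size = 16  # --decode-chunk-size default (subsampled frames)
left_context = 64       # --left-context default (subsampled frames)

chunk_ms = decode_chunk_size * subsampling_factor * frame_shift_ms
context_ms = left_context * subsampling_factor * frame_shift_ms
print(chunk_ms, context_ms)  # 640 2560 -> 0.64 s chunks, 2.56 s of left context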
@@ -331,7 +345,7 @@ def get_parser():
         "--rnn-lm-scale",
         type=float,
         default=0.0,
-        help="""Used only when --method is modified_beam_search3.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the path to RNN LM exp dir.
         """,
     )
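
Note: a rough sketch of what the scale does during shallow fusion, assuming that at each beam-search step the transducer's log-probabilities are combined log-linearly with the RNN LM's log-probabilities. This is a generic illustration under that assumption, not the repository's implementation.

import torch

def fuse_scores(
    transducer_logprobs: torch.Tensor,  # (num_hyps, vocab_size)
    rnnlm_logprobs: torch.Tensor,       # (num_hyps, vocab_size)
    rnnlm_scale: float,                 # corresponds to --rnn-lm-scale
) -> torch.Tensor:
    # Log-linear interpolation: with scale 0.0 (the default) the LM has no
    # effect; larger scales let the LM re-rank the hypotheses.
    return transducer_logprobs + rnnlm_scale * rnnlm_logprobs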
@@ -430,7 +444,7 @@ def decode_one_batch(
       word_table:
         The word symbol table.
       decoding_graph:
-        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
         only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
         fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
     Returns:
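
Note: for reference, the trivial graph mentioned in the docstring is built directly from the vocabulary. A minimal sketch follows; the vocab size and device are chosen for illustration, while in the script the size comes from the BPE model.

import k2
import torch

# Build the trivial decoding graph used when no LG/HLG is supplied.
vocab_size = 500
device = torch.device("cpu")
decoding_graph = k2.trivial_graph(vocab_size - 1, device=device)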
@@ -606,7 +620,7 @@ def decode_dataset(
     decoding_graph: Optional[k2.Fsa] = None,
     rnnlm: Optional[RnnLmModel] = None,
     rnnlm_scale: float = 1.0,
-) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
     Args:
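
Note: the widened return type adds the cut id as the first element of each result tuple (previously each entry was just a reference/hypothesis pair). An illustrative entry, with made-up values:

from typing import Dict, List, Tuple

# Shape of the new results mapping: decoding-key -> list of
# (cut_id, reference_words, hypothesis_words).  Values are invented.
results: Dict[str, List[Tuple[str, List[str], List[str]]]] = {
    "greedy_search": [
        ("1089-134686-0001", ["HELLO", "WORLD"], ["HELLO", "WORLDS"]),
    ],
}

The same tuple shape is what save_results() now expects in the next hunk.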
@@ -683,7 +697,7 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
@@ -751,7 +765,9 @@ def main():
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
@@ -791,6 +807,11 @@ def main():
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
     logging.info(params)
     logging.info("About to create model")