minor fixes

This commit is contained in:
marcoyang 2022-11-02 18:11:39 +08:00
parent 9a01b9098d
commit fb45b95c90

View File

@ -86,7 +86,7 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search with RNNLM shallow fusion (with LG)
./pruned_transducer_stateless5/decode.py \
--epoch 35 \
@ -103,7 +103,7 @@ Usage:
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
"""
@ -198,7 +198,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless2/exp",
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)
@ -228,7 +228,7 @@ def get_parser():
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified-beam-search3 # for rnn lm shallow fusion
- modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@ -265,7 +265,21 @@ def get_parser():
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--max-contexts",
type=int,
@ -317,7 +331,7 @@ def get_parser():
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
@ -331,7 +345,7 @@ def get_parser():
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified_beam_search3.
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
@ -430,7 +444,7 @@ def decode_one_batch(
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
@ -560,7 +574,7 @@ def decode_one_batch(
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
@ -606,7 +620,7 @@ def decode_dataset(
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
@ -683,7 +697,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
@ -751,7 +765,9 @@ def main():
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
@ -791,6 +807,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")