mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
minor fixes
This commit is contained in:
parent
9a01b9098d
commit
fb45b95c90
@ -86,7 +86,7 @@ Usage:
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
|
||||
(8) modified beam search with RNNLM shallow fusion (with LG)
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 35 \
|
||||
@ -103,7 +103,7 @@ Usage:
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights 1
|
||||
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@ -198,7 +198,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="lstm_transducer_stateless2/exp",
|
||||
default="pruned_transducer_stateless5/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
@ -228,7 +228,7 @@ def get_parser():
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
- modified-beam-search3 # for rnn lm shallow fusion
|
||||
- modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
@ -265,7 +265,21 @@ def get_parser():
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
@ -317,7 +331,7 @@ def get_parser():
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
@ -331,7 +345,7 @@ def get_parser():
|
||||
"--rnn-lm-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""Used only when --method is modified_beam_search3.
|
||||
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||
It specifies the path to RNN LM exp dir.
|
||||
""",
|
||||
)
|
||||
@ -430,7 +444,7 @@ def decode_one_batch(
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
@ -560,7 +574,7 @@ def decode_one_batch(
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
@ -606,7 +620,7 @@ def decode_dataset(
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
@ -683,7 +697,7 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
@ -751,7 +765,9 @@ def main():
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.simulate_streaming:
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
@ -791,6 +807,11 @@ def main():
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "Decoding in streaming requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
Loading…
x
Reference in New Issue
Block a user