Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 17:42:21 +00:00)

Commit 2a52b8c125 ("update docs"), parent b62fd917ae
@@ -235,7 +235,7 @@ def get_parser():
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
           - modified_beam_search_ngram_rescoring
-          - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -329,7 +329,7 @@ def get_parser():
         "--rnn-lm-scale",
         type=float,
         default=0.0,
-        help="""Used only when --method is modified_beam_search3.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the path to RNN LM exp dir.
         """,
     )
@@ -338,7 +338,7 @@ def get_parser():
         "--rnn-lm-exp-dir",
         type=str,
         default="rnn_lm/exp",
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the path to RNN LM exp dir.
         """,
     )
@@ -347,7 +347,7 @@ def get_parser():
         "--rnn-lm-epoch",
         type=int,
         default=7,
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the checkpoint to use.
         """,
     )
@@ -356,7 +356,7 @@ def get_parser():
         "--rnn-lm-avg",
         type=int,
         default=2,
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the number of checkpoints to average.
         """,
     )
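The four hunks above only repoint the help text at the new method name. For orientation, here is a minimal, self-contained argparse sketch of the same option group; it is an illustration, not the repo's decode.py, and the parse_args list is a made-up invocation:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--decoding-method", type=str, default="greedy_search")
parser.add_argument(
    "--rnn-lm-scale",
    type=float,
    default=0.0,
    help="LM weight, used only with modified_beam_search_rnnlm_shallow_fusion",
)
parser.add_argument("--rnn-lm-exp-dir", type=str, default="rnn_lm/exp")
parser.add_argument("--rnn-lm-epoch", type=int, default=7)
parser.add_argument("--rnn-lm-avg", type=int, default=2)

# hypothetical invocation exercising the shallow-fusion options
args = parser.parse_args(
    [
        "--decoding-method", "modified_beam_search_rnnlm_shallow_fusion",
        "--rnn-lm-scale", "0.3",
    ]
)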
@@ -911,14 +911,20 @@ def main():
     model.to(device)
     model.eval()

-    lm_filename = f"{params.tokens_ngram}gram.fst.txt"
-    logging.info(f"lm filename: {lm_filename}")
-    ngram_lm = NgramLm(
-        str(params.lang_dir / lm_filename),
-        backoff_id=params.backoff_id,
-        is_binary=False,
-    )
-    logging.info(f"num states: {ngram_lm.lm.num_states}")
+    # only load N-gram LM when needed
+    if "ngram" in params.decoding_method:
+        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+        logging.info(f"lm filename: {lm_filename}")
+        ngram_lm = NgramLm(
+            str(params.lang_dir / lm_filename),
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {ngram_lm.lm.num_states}")
+    else:
+        ngram_lm = None
+        ngram_lm_scale = None

     # only load rnnlm if used
     if "rnnlm" in params.decoding_method:
         rnn_lm_scale = params.rnn_lm_scale
@@ -941,6 +947,7 @@ def main():

     else:
         rnn_lm_model = None
+        rnn_lm_scale = 0.0

     if "fast_beam_search" in params.decoding_method:
         if params.decoding_method == "fast_beam_search_nbest_LG":
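Taken together, these two hunks make each LM optional: it is constructed only when the chosen decoding method actually uses it, and its scale falls back to a neutral value otherwise; the hunk below then passes the local (possibly None) values instead of the raw params. A compact sketch of the same pattern, where the placeholder loaders stand in for NgramLm and the RNN LM builder (all names here are illustrative, not icefall's API):

from typing import Optional, Tuple


def maybe_load_lms(
    decoding_method: str, rnn_lm_scale: float
) -> Tuple[Optional[object], Optional[object], float]:
    """Construct each LM only if `decoding_method` uses it."""
    ngram_lm = None
    rnn_lm_model = None
    if "ngram" in decoding_method:
        ngram_lm = object()  # placeholder for NgramLm(...)
    if "rnnlm" in decoding_method:
        rnn_lm_model = object()  # placeholder for loading/averaging the RNN LM
    else:
        rnn_lm_scale = 0.0  # neutral scale when no RNN LM is loaded
    return ngram_lm, rnn_lm_model, rnn_lm_scale

For example, maybe_load_lms("modified_beam_search_rnnlm_shallow_fusion", 0.3) would return (None, <rnn lm>, 0.3), while a plain "greedy_search" run returns (None, None, 0.0) and skips both loads.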
@@ -987,7 +994,7 @@ def main():
         word_table=word_table,
         decoding_graph=decoding_graph,
         ngram_lm=ngram_lm,
-        ngram_lm_scale=params.ngram_lm_scale,
+        ngram_lm_scale=ngram_lm_scale,
         rnnlm=rnn_lm_model,
         rnnlm_scale=rnn_lm_scale,
     )
@@ -17,7 +17,7 @@

 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import k2
 import sentencepiece as spm
@@ -729,8 +729,15 @@ class Hypothesis:

     # timestamp[i] is the frame index after subsampling
     # on which ys[i] is decoded
-    timestamp: List[int]
+    timestamp: List[int] = None
+
+    # the lm score for next token given the current ys
+    lm_score: Optional[torch.Tensor] = None
+
+    # the RNNLM states (h and c in LSTM)
+    state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

     # N-gram LM state
     state_cost: Optional[NgramLmStateCost] = None

     @property
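The new fields carry the per-hypothesis LM bookkeeping that shallow fusion needs: the LM score for the next token and the RNN LM's LSTM (h, c) state. A stripped-down, hypothetical version of such a dataclass (not icefall's full Hypothesis) looks like this:

from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import torch


@dataclass
class Hyp:
    # decoded token ids so far
    ys: List[int] = field(default_factory=list)
    # frame index (after subsampling) at which ys[i] was emitted
    timestamp: List[int] = field(default_factory=list)
    # LM log-prob for the next token given ys
    lm_score: Optional[torch.Tensor] = None
    # LSTM (h, c) states of the RNN LM
    state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

This sketch uses default_factory rather than the diff's `timestamp: List[int] = None`; both choices satisfy the dataclass restriction on mutable defaults, but default_factory spares callers from handling a None list.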
@@ -1989,8 +1996,15 @@ def modified_beam_search_rnnlm_shallow_fusion(
         ragged_log_probs = k2.RaggedTensor(
             shape=log_probs_shape, value=log_probs
         )

-        # for all hyps with a non-blank new token, score it
+        """
+        for all hyps with a non-blank new token, score this token.
+        It is a little confusing here because this for-loop
+        looks very similar to the one below. Here, we go through all
+        top-k tokens and only add the non-blank ones to the token_list.
+        The RNNLM will score those tokens given the LM states. Note that
+        the variable `scores` is the LM score after seeing the new
+        non-blank token.
+        """
         token_list = []
         hs = []
         cs = []
@@ -2007,11 +2021,12 @@ def modified_beam_search_rnnlm_shallow_fusion(

             new_token = topk_token_indexes[k]
             if new_token not in (blank_id, unk_id):
                 assert new_token != 0, new_token
                 token_list.append([new_token])
+                # store the LSTM states
                 hs.append(hyp.state[0])
                 cs.append(hyp.state[1])

         # forward RNNLM to get new states and scores
         if len(token_list) != 0:
             tokens_to_score = (
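The loop being documented collects every non-blank top-k token together with its parent hypothesis's LSTM state, so the RNN LM can score all of them in a single batched forward pass. A hedged sketch of that batching step, where run_lm is a stand-in for the real RNN LM call rather than icefall's API:

from typing import List, Tuple

import torch

blank_id, unk_id = 0, 2


def collect_non_blank(
    topk_tokens: List[int],
    states: List[Tuple[torch.Tensor, torch.Tensor]],
) -> List[List[int]]:
    token_list, hs, cs = [], [], []
    for tok, (h, c) in zip(topk_tokens, states):
        if tok not in (blank_id, unk_id):
            token_list.append([tok])
            hs.append(h)  # store the LSTM states of the parent hyp
            cs.append(c)
    if token_list:
        # one batched forward: stack the (num_layers, 1, hidden) states
        # along the batch dimension
        h = torch.cat(hs, dim=1)
        c = torch.cat(cs, dim=1)
        # scores, (h, c) = run_lm(token_list, (h, c))  # placeholder call
    return token_list

Batching here matters because scoring hypotheses one at a time would run the LM once per beam entry per frame; stacking the states turns that into a single forward per decoding step.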
@@ -228,7 +228,7 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
-          - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -354,7 +354,7 @@ def get_parser():
         "--rnn-lm-exp-dir",
         type=str,
         default="rnn_lm/exp",
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the path to RNN LM exp dir.
         """,
     )
@@ -363,7 +363,7 @@ def get_parser():
         "--rnn-lm-epoch",
         type=int,
         default=7,
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the checkpoint to use.
         """,
     )
@@ -372,7 +372,7 @@ def get_parser():
         "--rnn-lm-avg",
         type=int,
         default=2,
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the number of checkpoints to average.
         """,
     )
|
Loading…
x
Reference in New Issue
Block a user