mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
update docs
This commit is contained in:
parent
b62fd917ae
commit
2a52b8c125
@ -235,7 +235,7 @@ def get_parser():
|
|||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
- modified_beam_search_ngram_rescoring
|
- modified_beam_search_ngram_rescoring
|
||||||
- modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
||||||
If you use fast_beam_search_nbest_LG, you have to specify
|
If you use fast_beam_search_nbest_LG, you have to specify
|
||||||
`--lang-dir`, which should contain `LG.pt`.
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
""",
|
""",
|
||||||
@ -329,7 +329,7 @@ def get_parser():
|
|||||||
"--rnn-lm-scale",
|
"--rnn-lm-scale",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.0,
|
||||||
help="""Used only when --method is modified_beam_search3.
|
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||||
It specifies the path to RNN LM exp dir.
|
It specifies the path to RNN LM exp dir.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -338,7 +338,7 @@ def get_parser():
|
|||||||
"--rnn-lm-exp-dir",
|
"--rnn-lm-exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="rnn_lm/exp",
|
default="rnn_lm/exp",
|
||||||
help="""Used only when --method is rnn-lm.
|
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||||
It specifies the path to RNN LM exp dir.
|
It specifies the path to RNN LM exp dir.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -347,7 +347,7 @@ def get_parser():
|
|||||||
"--rnn-lm-epoch",
|
"--rnn-lm-epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=7,
|
default=7,
|
||||||
help="""Used only when --method is rnn-lm.
|
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||||
It specifies the checkpoint to use.
|
It specifies the checkpoint to use.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -356,7 +356,7 @@ def get_parser():
|
|||||||
"--rnn-lm-avg",
|
"--rnn-lm-avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="""Used only when --method is rnn-lm.
|
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||||
It specifies the number of checkpoints to average.
|
It specifies the number of checkpoints to average.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -911,6 +911,8 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
# only load N-gram LM when needed
|
||||||
|
if "ngram" in params.decoding_method:
|
||||||
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
|
||||||
logging.info(f"lm filename: {lm_filename}")
|
logging.info(f"lm filename: {lm_filename}")
|
||||||
ngram_lm = NgramLm(
|
ngram_lm = NgramLm(
|
||||||
@ -919,6 +921,10 @@ def main():
|
|||||||
is_binary=False,
|
is_binary=False,
|
||||||
)
|
)
|
||||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||||
|
else:
|
||||||
|
ngram_lm = None
|
||||||
|
ngram_lm_scale = None
|
||||||
|
|
||||||
# only load rnnlm if used
|
# only load rnnlm if used
|
||||||
if "rnnlm" in params.decoding_method:
|
if "rnnlm" in params.decoding_method:
|
||||||
rnn_lm_scale = params.rnn_lm_scale
|
rnn_lm_scale = params.rnn_lm_scale
|
||||||
@ -941,6 +947,7 @@ def main():
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
rnn_lm_model = None
|
rnn_lm_model = None
|
||||||
|
rnn_lm_scale = 0.0
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
@ -987,7 +994,7 @@ def main():
|
|||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
ngram_lm=ngram_lm,
|
ngram_lm=ngram_lm,
|
||||||
ngram_lm_scale=params.ngram_lm_scale,
|
ngram_lm_scale=ngram_lm_scale,
|
||||||
rnnlm=rnn_lm_model,
|
rnnlm=rnn_lm_model,
|
||||||
rnnlm_scale=rnn_lm_scale,
|
rnnlm_scale=rnn_lm_scale,
|
||||||
)
|
)
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
@ -729,8 +729,15 @@ class Hypothesis:
|
|||||||
|
|
||||||
# timestamp[i] is the frame index after subsampling
|
# timestamp[i] is the frame index after subsampling
|
||||||
# on which ys[i] is decoded
|
# on which ys[i] is decoded
|
||||||
timestamp: List[int]
|
timestamp: List[int] = None
|
||||||
|
|
||||||
|
# the lm score for next token given the current ys
|
||||||
|
lm_score: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# the RNNLM states (h and c in LSTM)
|
||||||
|
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||||
|
|
||||||
|
# N-gram LM state
|
||||||
state_cost: Optional[NgramLmStateCost] = None
|
state_cost: Optional[NgramLmStateCost] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -1989,8 +1996,15 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
|||||||
ragged_log_probs = k2.RaggedTensor(
|
ragged_log_probs = k2.RaggedTensor(
|
||||||
shape=log_probs_shape, value=log_probs
|
shape=log_probs_shape, value=log_probs
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
# for all hyps with a non-blank new token, score it
|
for all hyps with a non-blank new token, score this token.
|
||||||
|
It is a little confusing here because this for-loop
|
||||||
|
looks very similar to the one below. Here, we go through all
|
||||||
|
top-k tokens and only add the non-blanks ones to the token_list.
|
||||||
|
The RNNLM will score those tokens given the LM states. Note that
|
||||||
|
the variable `scores` is the LM score after seeing the new
|
||||||
|
non-blank token.
|
||||||
|
"""
|
||||||
token_list = []
|
token_list = []
|
||||||
hs = []
|
hs = []
|
||||||
cs = []
|
cs = []
|
||||||
@ -2007,11 +2021,12 @@ def modified_beam_search_rnnlm_shallow_fusion(
|
|||||||
|
|
||||||
new_token = topk_token_indexes[k]
|
new_token = topk_token_indexes[k]
|
||||||
if new_token not in (blank_id, unk_id):
|
if new_token not in (blank_id, unk_id):
|
||||||
|
|
||||||
assert new_token != 0, new_token
|
assert new_token != 0, new_token
|
||||||
token_list.append([new_token])
|
token_list.append([new_token])
|
||||||
|
# store the LSTM states
|
||||||
hs.append(hyp.state[0])
|
hs.append(hyp.state[0])
|
||||||
cs.append(hyp.state[1])
|
cs.append(hyp.state[1])
|
||||||
|
|
||||||
# forward RNNLM to get new states and scores
|
# forward RNNLM to get new states and scores
|
||||||
if len(token_list) != 0:
|
if len(token_list) != 0:
|
||||||
tokens_to_score = (
|
tokens_to_score = (
|
||||||
|
@ -228,7 +228,7 @@ def get_parser():
|
|||||||
- fast_beam_search_nbest
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
- fast_beam_search_nbest_LG
|
- fast_beam_search_nbest_LG
|
||||||
- modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
|
||||||
If you use fast_beam_search_nbest_LG, you have to specify
|
If you use fast_beam_search_nbest_LG, you have to specify
|
||||||
`--lang-dir`, which should contain `LG.pt`.
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
""",
|
""",
|
||||||
@ -354,7 +354,7 @@ def get_parser():
|
|||||||
"--rnn-lm-exp-dir",
|
"--rnn-lm-exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="rnn_lm/exp",
|
default="rnn_lm/exp",
|
||||||
help="""Used only when --method is rnn-lm.
|
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||||
It specifies the path to RNN LM exp dir.
|
It specifies the path to RNN LM exp dir.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -363,7 +363,7 @@ def get_parser():
|
|||||||
"--rnn-lm-epoch",
|
"--rnn-lm-epoch",
|
||||||
type=int,
|
type=int,
|
||||||
default=7,
|
default=7,
|
||||||
help="""Used only when --method is rnn-lm.
|
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||||
It specifies the checkpoint to use.
|
It specifies the checkpoint to use.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
@ -372,7 +372,7 @@ def get_parser():
|
|||||||
"--rnn-lm-avg",
|
"--rnn-lm-avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="""Used only when --method is rnn-lm.
|
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
|
||||||
It specifies the number of checkpoints to average.
|
It specifies the number of checkpoints to average.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user