update docs

marcoyang 2022-11-03 11:10:21 +08:00
parent b62fd917ae
commit 2a52b8c125
3 changed files with 45 additions and 23 deletions

@@ -235,7 +235,7 @@ def get_parser():
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_ngram_rescoring
-   - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+   - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@@ -329,7 +329,7 @@ def get_parser():
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified_beam_search3.
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the scale of the RNN LM in shallow fusion.
""",
)
@@ -338,7 +338,7 @@ def get_parser():
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
@@ -347,7 +347,7 @@ def get_parser():
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
@@ -356,7 +356,7 @@ def get_parser():
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)
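
Taken together, the four --rnn-lm-* flags tell the decoder where to find the RNN LM and which checkpoints to combine before shallow fusion. As a rough illustration of how --rnn-lm-epoch and --rnn-lm-avg interact, here is a minimal checkpoint-averaging sketch; the epoch-{i}.pt naming and the "model" key are assumptions for illustration, not the icefall API:

import torch

def average_checkpoints(exp_dir: str, epoch: int, avg: int) -> dict:
    # Average the state_dicts of the `avg` checkpoints ending at `epoch`.
    # The file naming and dict layout below are assumed, not taken from icefall.
    names = [f"{exp_dir}/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]
    avg_state = None
    for name in names:
        state = torch.load(name, map_location="cpu")["model"]
        if avg_state is None:
            avg_state = {k: v.float().clone() for k, v in state.items()}
        else:
            for k in avg_state:
                avg_state[k] += state[k].float()
    return {k: v / len(names) for k, v in avg_state.items()}

With the defaults above (--rnn-lm-epoch 7 --rnn-lm-avg 2), this would average epoch-6.pt and epoch-7.pt.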
@@ -911,14 +911,20 @@ def main():
model.to(device)
model.eval()
-    lm_filename = f"{params.tokens_ngram}gram.fst.txt"
-    logging.info(f"lm filename: {lm_filename}")
-    ngram_lm = NgramLm(
-        str(params.lang_dir / lm_filename),
-        backoff_id=params.backoff_id,
-        is_binary=False,
-    )
-    logging.info(f"num states: {ngram_lm.lm.num_states}")
+    # only load N-gram LM when needed
+    if "ngram" in params.decoding_method:
+        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+        logging.info(f"lm filename: {lm_filename}")
+        ngram_lm = NgramLm(
+            str(params.lang_dir / lm_filename),
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {ngram_lm.lm.num_states}")
+    else:
+        ngram_lm = None
+        ngram_lm_scale = None
# only load rnnlm if used
if "rnnlm" in params.decoding_method:
rnn_lm_scale = params.rnn_lm_scale
@@ -941,6 +947,7 @@
else:
rnn_lm_model = None
rnn_lm_scale = 0.0
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
@@ -987,7 +994,7 @@
word_table=word_table,
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
-        ngram_lm_scale=params.ngram_lm_scale,
+        ngram_lm_scale=ngram_lm_scale,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
)
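
Passing the local ngram_lm_scale (now None when no N-gram LM is loaded) alongside rnn_lm_scale makes the fusion weights explicit at the call site. For context, shallow fusion combines the per-token scores roughly like this; a sketch of the general technique, not the exact icefall code:

import torch

def fused_log_prob(
    am_lp: torch.Tensor,          # transducer/acoustic token log-probs
    rnnlm_lp: torch.Tensor,       # RNNLM next-token log-probs
    rnnlm_scale: float,
    ngram_lp: torch.Tensor = None,
    ngram_scale: float = None,
) -> torch.Tensor:
    # Shallow fusion: add scaled LM log-probs to the model's token log-probs.
    score = am_lp + rnnlm_scale * rnnlm_lp
    if ngram_lp is not None:
        score = score + ngram_scale * ngram_lp
    return score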

@@ -17,7 +17,7 @@
import warnings
from dataclasses import dataclass
- from typing import Dict, List, Optional, Union
+ from typing import Dict, List, Optional, Tuple, Union
import k2
import sentencepiece as spm
@@ -729,8 +729,15 @@ class Hypothesis:
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
-    timestamp: List[int]
+    timestamp: List[int] = None
+    # the lm score for next token given the current ys
+    lm_score: Optional[torch.Tensor] = None
+    # the RNNLM states (h and c in LSTM)
+    state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    # N-gram LM state
+    state_cost: Optional[NgramLmStateCost] = None
@property
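
Given the new optional fields, a hypothesis for RNNLM shallow fusion might be seeded as below. This is a sketch only: the ys/log_prob fields, blank_id, context_size, vocab_size, and the LSTM geometry are assumptions about the surrounding code, not values taken from this diff.

import torch

num_layers, hidden_dim, vocab_size = 2, 512, 500  # assumed RNNLM geometry
blank_id, context_size = 0, 2                     # assumed decoder setup

hyp = Hypothesis(
    ys=[blank_id] * context_size,  # decoder history
    log_prob=torch.zeros(1),       # total score so far
    state=(
        torch.zeros(num_layers, 1, hidden_dim),  # LSTM h
        torch.zeros(num_layers, 1, hidden_dim),  # LSTM c
    ),
    lm_score=torch.zeros(vocab_size),  # LM log-probs before any real token
)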
@@ -1989,8 +1996,15 @@ def modified_beam_search_rnnlm_shallow_fusion(
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
-        # for all hyps with a non-blank new token, score it
+        """
+        for all hyps with a non-blank new token, score this token.
+        It is a little confusing here because this for-loop
+        looks very similar to the one below. Here, we go through all
+        top-k tokens and only add the non-blank ones to the token_list.
+        The RNNLM will score those tokens given the LM states. Note that
+        the variable `scores` is the LM score after seeing the new
+        non-blank token.
+        """
token_list = []
hs = []
cs = []
@@ -2007,11 +2021,12 @@
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
assert new_token != 0, new_token
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
# forward RNNLM to get new states and scores
if len(token_list) != 0:
tokens_to_score = (
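
The loop above gathers one pending non-blank token per hypothesis together with its LSTM state, so the forward pass that follows can score them all in one batched call. A self-contained sketch with a plain torch LSTM standing in for the RNNLM (the geometry and the output projection are assumptions, not the icefall model):

import torch
import torch.nn as nn

vocab_size, hidden_dim, num_layers = 500, 512, 2  # assumed RNNLM geometry
embedding = nn.Embedding(vocab_size, hidden_dim)
lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers)
out_proj = nn.Linear(hidden_dim, vocab_size)

# token_list: one [new_token] per hyp; hs/cs: that hyp's (h, c) LSTM states
tokens = torch.tensor(token_list).t()  # (seq_len=1, num_hyps)
h = torch.cat(hs, dim=1)               # (num_layers, num_hyps, hidden_dim)
c = torch.cat(cs, dim=1)
out, (h_new, c_new) = lstm(embedding(tokens), (h, c))
lm_scores = out_proj(out).log_softmax(dim=-1).squeeze(0)  # (num_hyps, vocab_size)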

@@ -228,7 +228,7 @@ def get_parser():
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
-   - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+   - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@@ -354,7 +354,7 @@ def get_parser():
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
@@ -363,7 +363,7 @@ def get_parser():
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
@@ -372,7 +372,7 @@ def get_parser():
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)