update docs

marcoyang 2022-11-03 11:10:21 +08:00
parent b62fd917ae
commit 2a52b8c125
3 changed files with 45 additions and 23 deletions

View File

@@ -235,7 +235,7 @@ def get_parser():
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
           - modified_beam_search_ngram_rescoring
-          - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -329,7 +329,7 @@ def get_parser():
         "--rnn-lm-scale",
         type=float,
         default=0.0,
-        help="""Used only when --method is modified_beam_search3.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the path to RNN LM exp dir.
         """,
     )
@@ -338,7 +338,7 @@ def get_parser():
         "--rnn-lm-exp-dir",
         type=str,
         default="rnn_lm/exp",
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the path to RNN LM exp dir.
         """,
     )
@@ -347,7 +347,7 @@ def get_parser():
         "--rnn-lm-epoch",
         type=int,
         default=7,
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the checkpoint to use.
         """,
     )
@@ -356,7 +356,7 @@ def get_parser():
         "--rnn-lm-avg",
         type=int,
         default=2,
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the number of checkpoints to average.
         """,
     )
@@ -911,6 +911,8 @@ def main():
     model.to(device)
     model.eval()

+    # only load N-gram LM when needed
+    if "ngram" in params.decoding_method:
         lm_filename = f"{params.tokens_ngram}gram.fst.txt"
         logging.info(f"lm filename: {lm_filename}")
         ngram_lm = NgramLm(
@@ -919,6 +921,10 @@ def main():
             is_binary=False,
         )
         logging.info(f"num states: {ngram_lm.lm.num_states}")
+    else:
+        ngram_lm = None
+        ngram_lm_scale = None
+
     # only load rnnlm if used
     if "rnnlm" in params.decoding_method:
         rnn_lm_scale = params.rnn_lm_scale
@@ -941,6 +947,7 @@ def main():
     else:
         rnn_lm_model = None
+        rnn_lm_scale = 0.0

     if "fast_beam_search" in params.decoding_method:
         if params.decoding_method == "fast_beam_search_nbest_LG":
@@ -987,7 +994,7 @@ def main():
             word_table=word_table,
             decoding_graph=decoding_graph,
             ngram_lm=ngram_lm,
-            ngram_lm_scale=params.ngram_lm_scale,
+            ngram_lm_scale=ngram_lm_scale,
            rnnlm=rnn_lm_model,
            rnnlm_scale=rnn_lm_scale,
        )
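
Taken together, the main() changes above make LM construction conditional: the N-gram LM is built only when the decoding method contains "ngram", the RNN LM only when it contains "rnnlm", and an unused LM falls back to None (with its scale neutralized) so the later decoding call can pass both arguments unconditionally. A minimal self-contained sketch of that pattern follows; the stand-in classes are illustrative only and are not icefall's NgramLm or RNN LM.

# Illustrative sketch only; not icefall code. Stand-in classes keep it runnable.
class FakeNgramLm:
    """Stand-in for the NgramLm object built in the diff above."""

    def __init__(self, path: str):
        self.path = path


class FakeRnnLm:
    """Stand-in for the RNN LM model."""


def build_lms(decoding_method: str, ngram_lm_scale, rnn_lm_scale):
    """Build LMs only when the decoding method needs them; otherwise return
    None and a neutral scale, mirroring the logic added to main()."""
    if "ngram" in decoding_method:
        ngram_lm = FakeNgramLm("2gram.fst.txt")
    else:
        ngram_lm = None
        ngram_lm_scale = None

    if "rnnlm" in decoding_method:
        rnn_lm_model = FakeRnnLm()
    else:
        rnn_lm_model = None
        rnn_lm_scale = 0.0

    return ngram_lm, ngram_lm_scale, rnn_lm_model, rnn_lm_scale


print(build_lms("modified_beam_search_rnnlm_shallow_fusion", 0.1, 0.3))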

View File

@@ -17,7 +17,7 @@
 import warnings
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import k2
 import sentencepiece as spm
@@ -729,8 +729,15 @@ class Hypothesis:
     # timestamp[i] is the frame index after subsampling
     # on which ys[i] is decoded
-    timestamp: List[int]
+    timestamp: List[int] = None

+    # the lm score for next token given the current ys
+    lm_score: Optional[torch.Tensor] = None
+
+    # the RNNLM states (h and c in LSTM)
+    state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+    # N-gram LM state
     state_cost: Optional[NgramLmStateCost] = None

     @property
@@ -1989,8 +1996,15 @@ def modified_beam_search_rnnlm_shallow_fusion(
         ragged_log_probs = k2.RaggedTensor(
             shape=log_probs_shape, value=log_probs
         )
-        # for all hyps with a non-blank new token, score it
+        """
+        for all hyps with a non-blank new token, score this token.
+        It is a little confusing here because this for-loop
+        looks very similar to the one below. Here, we go through all
+        top-k tokens and only add the non-blanks ones to the token_list.
+        The RNNLM will score those tokens given the LM states. Note that
+        the variable `scores` is the LM score after seeing the new
+        non-blank token.
+        """
         token_list = []
         hs = []
         cs = []
@@ -2007,11 +2021,12 @@ def modified_beam_search_rnnlm_shallow_fusion(
             new_token = topk_token_indexes[k]
             if new_token not in (blank_id, unk_id):
                 assert new_token != 0, new_token
                 token_list.append([new_token])
+                # store the LSTM states
                 hs.append(hyp.state[0])
                 cs.append(hyp.state[1])

         # forward RNNLM to get new states and scores
         if len(token_list) != 0:
             tokens_to_score = (
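
The new Hypothesis fields and the loop above work together: for each surviving hypothesis, the non-blank top-k tokens are collected into token_list along with that hypothesis's LSTM state (h, c), so the RNN LM can score all of them in one batched forward pass. The sketch below mirrors the added fields and the gathering step under those assumptions; the trimmed-down dataclass and the collect_tokens_to_score helper are illustrative stand-ins, not the beam_search.py implementation.

# Illustrative sketch; mirrors the Hypothesis fields added in this commit but
# is not the beam_search.py implementation.
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import torch


@dataclass
class Hypothesis:
    ys: List[int] = field(default_factory=list)
    # LM score for the next token given the current ys
    lm_score: Optional[torch.Tensor] = None
    # RNN LM states (h and c of an LSTM)
    state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None


def collect_tokens_to_score(
    expansions: List[Tuple[Hypothesis, int]], blank_id: int = 0
):
    """Gather each non-blank token together with the LSTM state of the
    hypothesis that produced it, so one batched RNN LM call can score them."""
    token_list, hs, cs = [], [], []
    for hyp, new_token in expansions:
        if new_token != blank_id:
            token_list.append([new_token])
            hs.append(hyp.state[0])  # h
            cs.append(hyp.state[1])  # c
    return token_list, hs, cs


# Toy usage: two expansions, one of them blank.
h0 = (torch.zeros(1, 4), torch.zeros(1, 4))
hyps = [Hypothesis(state=h0), Hypothesis(state=h0)]
print(collect_tokens_to_score([(hyps[0], 5), (hyps[1], 0)]))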

View File

@@ -228,7 +228,7 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
-          - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+          - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -354,7 +354,7 @@ def get_parser():
         "--rnn-lm-exp-dir",
         type=str,
         default="rnn_lm/exp",
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the path to RNN LM exp dir.
         """,
     )
@@ -363,7 +363,7 @@ def get_parser():
         "--rnn-lm-epoch",
         type=int,
         default=7,
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the checkpoint to use.
         """,
     )
@@ -372,7 +372,7 @@ def get_parser():
         "--rnn-lm-avg",
         type=int,
         default=2,
-        help="""Used only when --method is rnn-lm.
+        help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
         It specifies the number of checkpoints to average.
         """,
     )