Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
Add comments for rnn rescoring

parent 79d1316905
commit 30dd1479fd
@@ -1219,8 +1219,10 @@ def fast_beam_search_with_nbest_rescoring(
     temperature: float = 1.0,
 ) -> Dict[str, List[List[int]]]:
     """It limits the maximum number of symbols per frame to 1.
-    A lattice is first obtained using modified beam search, and then
-    the shortest path within the lattice is used as the final output.
+    A lattice is first obtained using fast beam search, num_path are selected
+    and rescored using a given language model. The shortest path within the
+    lattice is used as the final output.
+
     Args:
       model:
         An instance of `Transducer`.
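
Note (not part of this commit): the flow the corrected docstring describes, sketched with icefall's decode helpers. This is a simplified illustration, not the function body; `lattice` is assumed to be the FsaVec produced by fast beam search, and the num-paths/scale values are illustrative.

    # Sketch only: n-best extraction over a fast-beam-search lattice,
    # assuming icefall's Nbest helper is available.
    from icefall.decode import Nbest

    def extract_nbest(lattice, num_paths: int = 200, nbest_scale: float = 0.5) -> Nbest:
        # Sample num_paths paths from the lattice; nbest_scale < 1 flattens
        # the score distribution so sampling is less greedy.
        nbest = Nbest.from_lattice(
            lattice=lattice,
            num_paths=num_paths,
            use_double_scores=True,
            nbest_scale=nbest_scale,
        )
        # Re-intersect with the lattice so each unique path carries its
        # acoustic and graph scores for the later rescoring step.
        return nbest.intersect(lattice)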
@@ -1366,14 +1368,17 @@ def fast_beam_search_with_nbest_rnn_rescoring(
     sp: spm.SentencePieceProcessor,
     word_table: k2.SymbolTable,
     rnn_lm_model: torch.nn.Module,
+    rnn_lm_scale_list: List[float],
     oov_word: str = "<UNK>",
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
 ) -> Dict[str, List[List[int]]]:
     """It limits the maximum number of symbols per frame to 1.
-    A lattice is first obtained using modified beam search, and then
-    the shortest path within the lattice is used as the final output.
+    A lattice is first obtained using fast beam search, num_path are selected
+    and rescored using a given language model and a rnn-lm.
+    The shortest path within the lattice is used as the final output.
+
     Args:
       model:
         An instance of `Transducer`.
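
Note (not part of this commit): the function returns a Dict[str, List[List[int]]] keyed by the scale grid shown in a later hunk of this diff, with one token-ID list per utterance. A self-contained sketch of converting that result to text:

    from typing import Dict, List

    def results_to_text(
        results: Dict[str, List[List[int]]], sp  # sp: spm.SentencePieceProcessor
    ) -> Dict[str, List[str]]:
        # Each key is f"ngram_lm_scale_{n}_rnn_lm_scale_{r}"; each value holds
        # the best token IDs per utterance under that scale pair.
        return {key: [sp.decode(ids) for ids in hyps] for key, hyps in results.items()}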
@@ -1400,6 +1405,10 @@ def fast_beam_search_with_nbest_rnn_rescoring(
         The BPE model.
       word_table:
         The word symbol table.
+      rnn_lm_model:
+        A rnn-lm model used for LM rescoring
+      rnn_lm_scale_list:
+        A list of floats representing RNN score scales.
       oov_word:
         OOV words are replaced with this word.
       use_double_scores:
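
Note (not part of this commit): in the icefall recipes such scale lists typically come from command-line flags. A minimal, purely hypothetical parser (the flag format is an assumption, not from this diff):

    from typing import List

    def parse_scales(s: str) -> List[float]:
        # e.g. "0.1,0.3,0.5" -> [0.1, 0.3, 0.5]
        return [float(v) for v in s.split(",") if v.strip()]

    rnn_lm_scale_list = parse_scales("0.1,0.3,0.5")  # illustrative values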
@@ -1501,9 +1510,7 @@ def fast_beam_search_with_nbest_rnn_rescoring(
     sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
     sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
-    x_tokens = sos_tokens.pad(
-        mode="constant", padding_value=model.decoder.blank_id
-    )
+    x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
     y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
 
     x_tokens = x_tokens.to(torch.int64)
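
Note (not part of this commit): the change above pads with the function-level blank_id instead of reaching into model.decoder.blank_id, presumably the same value fetched once. What the padding produces on a toy two-utterance batch (sos/eos id 1, blank id 0; only k2 and torch are assumed):

    import k2
    import torch

    blank_id = 0
    sos_tokens = k2.RaggedTensor([[1, 5, 9], [1, 3]])   # <sos> already prepended
    tokens_eos = k2.RaggedTensor([[5, 9, 1], [3, 1]])   # <eos> already appended

    row_splits = sos_tokens.shape.row_splits(1)          # tensor([0, 3, 5])
    sentence_lengths = row_splits[1:] - row_splits[:-1]  # tensor([3, 2])

    # pad() turns the ragged lists into a rectangular torch.Tensor.
    x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id).to(torch.int64)
    y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id).to(torch.int64)
    # x_tokens -> [[1, 5, 9], [1, 3, 0]]; y_tokens -> [[5, 9, 1], [3, 1, 0]]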
@@ -1517,7 +1524,7 @@ def fast_beam_search_with_nbest_rnn_rescoring(
 
     ans: Dict[str, List[List[int]]] = {}
     for n_scale in ngram_lm_scale_list:
-        for rnn_scale in ngram_lm_scale_list:
+        for rnn_scale in rnn_lm_scale_list:
             key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
             tot_scores = (
                 am_scores.values
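
Note (not part of this commit): this one-line fix matters. The inner loop previously re-iterated ngram_lm_scale_list, so the RNN LM scales were never consulted and the grid was wrong. A self-contained sketch of the corrected combination; the terms after am_scores.values are assumed from context, since the hunk is cut off there:

    from typing import Dict, List
    import torch

    def combine_scores(
        am_scores: torch.Tensor,        # per-path acoustic scores
        ngram_lm_scores: torch.Tensor,  # per-path n-gram LM scores
        rnn_lm_scores: torch.Tensor,    # per-path RNN LM scores
        ngram_lm_scale_list: List[float],
        rnn_lm_scale_list: List[float],
    ) -> Dict[str, torch.Tensor]:
        ans: Dict[str, torch.Tensor] = {}
        for n_scale in ngram_lm_scale_list:
            for rnn_scale in rnn_lm_scale_list:  # the fixed loop variable
                key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
                ans[key] = am_scores + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores
        return ans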
@@ -605,6 +605,7 @@ def decode_one_batch(
             sp=sp,
             word_table=word_table,
             rnn_lm_model=rnn_lm_model,
+            rnn_lm_scale_list=ngram_lm_scale_list,
             use_double_scores=True,
             nbest_scale=params.nbest_scale,
             temperature=params.temperature,
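
Note (not part of this commit): the call site passes ngram_lm_scale_list for the new rnn_lm_scale_list argument, i.e. one scale grid is reused for both LMs. Giving the RNN LM its own grid would just mean passing another list; the wiring below is hypothetical, not from this diff:

    # Hypothetical: a separate grid for the RNN LM instead of reusing
    # the n-gram one; values are illustrative.
    rnn_lm_scale_list = [0.1, 0.3, 0.5]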
@@ -1006,6 +1006,8 @@ def rescore_with_rnn_lm(
       An FsaVec with axes [utt][state][arc].
     num_paths:
       Number of paths to extract from the given lattice for rescoring.
+    rnn_lm_model:
+      A rnn-lm model used for LM rescoring
     model:
       A transformer model. See the class "Transformer" in
       conformer_ctc/transformer.py for its interface.
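
Note (not part of this commit): a sketch of how an RNN LM of the documented kind is queried during rescoring. It assumes a forward that takes padded input/target token tensors plus lengths and returns per-token negative log-likelihood of shape (N, L), as icefall's RnnLmModel does; that interface is an assumption here.

    import torch

    def rnn_lm_sentence_scores(
        rnn_lm_model: torch.nn.Module,
        x_tokens: torch.Tensor,          # (N, L), <sos>-prefixed, blank-padded
        y_tokens: torch.Tensor,          # (N, L), <eos>-suffixed, blank-padded
        sentence_lengths: torch.Tensor,  # (N,)
    ) -> torch.Tensor:
        # Assumed interface: forward returns NLL per token, with padding
        # positions handled internally given the true lengths.
        nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths)
        return -1 * nll.sum(dim=-1)      # higher is better for rescoring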