Add comments for rnn rescoring

This commit is contained in:
Erwan 2022-07-18 09:07:27 +02:00
parent 79d1316905
commit 30dd1479fd
3 changed files with 18 additions and 8 deletions

View File

@ -1219,8 +1219,10 @@ def fast_beam_search_with_nbest_rescoring(
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
A lattice is first obtained using fast beam search; the n-best paths are
then selected from it and rescored using a given language model. The
shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
@ -1366,14 +1368,17 @@ def fast_beam_search_with_nbest_rnn_rescoring(
sp: spm.SentencePieceProcessor,
word_table: k2.SymbolTable,
rnn_lm_model: torch.nn.Module,
rnn_lm_scale_list: List[float],
oov_word: str = "<UNK>",
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
A lattice is first obtained using fast beam search; the n-best paths are
then selected from it and rescored using a given language model and an
RNN language model.
The shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
@ -1400,6 +1405,10 @@ def fast_beam_search_with_nbest_rnn_rescoring(
The BPE model.
word_table:
The word symbol table.
rnn_lm_model:
An RNN language model used for LM rescoring.
rnn_lm_scale_list:
A list of floats representing RNN score scales.
oov_word:
OOV words are replaced with this word.
use_double_scores:
@ -1501,9 +1510,7 @@ def fast_beam_search_with_nbest_rnn_rescoring(
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
x_tokens = sos_tokens.pad(
mode="constant", padding_value=model.decoder.blank_id
)
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64)
@ -1517,7 +1524,7 @@ def fast_beam_search_with_nbest_rnn_rescoring(
ans: Dict[str, List[List[int]]] = {}
for n_scale in ngram_lm_scale_list:
for rnn_scale in ngram_lm_scale_list:
for rnn_scale in rnn_lm_scale_list:
key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
tot_scores = (
am_scores.values

View File

@ -605,6 +605,7 @@ def decode_one_batch(
sp=sp,
word_table=word_table,
rnn_lm_model=rnn_lm_model,
rnn_lm_scale_list=ngram_lm_scale_list,
use_double_scores=True,
nbest_scale=params.nbest_scale,
temperature=params.temperature,

View File

@ -1006,6 +1006,8 @@ def rescore_with_rnn_lm(
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
rnn_lm_model:
An RNN language model used for LM rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.