Add RNN-LM rescoring in fast beam search (#475)

This commit is contained in:
ezerhouni 2022-07-18 10:52:17 +02:00 committed by GitHub
parent aec222e2fe
commit 608473b4eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 325 additions and 11 deletions

View File

@ -24,7 +24,7 @@ import torch
from model import Transducer
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
from icefall.utils import add_eos, add_sos, get_texts
def fast_beam_search_one_best(
@ -46,7 +46,7 @@ def fast_beam_search_one_best(
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@ -106,7 +106,7 @@ def fast_beam_search_nbest_LG(
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@ -226,7 +226,7 @@ def fast_beam_search_nbest(
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@ -311,7 +311,7 @@ def fast_beam_search_nbest_oracle(
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@ -397,7 +397,7 @@ def fast_beam_search(
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@ -1219,13 +1219,15 @@ def fast_beam_search_with_nbest_rescoring(
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
A lattice is first obtained using fast beam search, num_path are selected
and rescored using a given language model. The shortest path within the
lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
@ -1350,3 +1352,190 @@ def fast_beam_search_with_nbest_rescoring(
ans[key] = hyps
return ans
def fast_beam_search_with_nbest_rnn_rescoring(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
ngram_lm_scale_list: List[float],
num_paths: int,
G: k2.Fsa,
sp: spm.SentencePieceProcessor,
word_table: k2.SymbolTable,
rnn_lm_model: torch.nn.Module,
rnn_lm_scale_list: List[float],
oov_word: str = "<UNK>",
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, num_path are selected
and rescored using a given language model and a rnn-lm.
The shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
ngram_lm_scale_list:
A list of floats representing LM score scales.
num_paths:
Number of paths to extract from the decoded lattice.
G:
An FsaVec containing only a single FSA. It is an n-gram LM.
sp:
The BPE model.
word_table:
The word symbol table.
rnn_lm_model:
A rnn-lm model used for LM rescoring
rnn_lm_scale_list:
A list of floats representing RNN score scales.
oov_word:
OOV words are replaced with this word.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
temperature:
Softmax temperature.
Returns:
Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
ngram LM scale value used during decoding, i.e., 0.1.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
am_scores = nbest.tot_scores()
# Now we need to compute the LM scores of each path.
# (1) Get the token IDs of each Path. We assume the decoding_graph
# is an acceptor, i.e., lattice is also an acceptor
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc]
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
tokens = tokens.remove_values_leq(0) # remove -1 and 0
token_list: List[List[int]] = tokens.tolist()
word_list: List[List[str]] = sp.decode(token_list)
assert isinstance(oov_word, str), oov_word
assert oov_word in word_table, oov_word
oov_word_id = word_table[oov_word]
word_ids_list: List[List[int]] = []
for words in word_list:
this_word_ids = []
for w in words.split():
if w in word_table:
this_word_ids.append(word_table[w])
else:
this_word_ids.append(oov_word_id)
word_ids_list.append(this_word_ids)
word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
num_unique_paths = len(word_ids_list)
b_to_a_map = torch.zeros(
num_unique_paths,
dtype=torch.int32,
device=lattice.device,
)
rescored_word_fsas = k2.intersect_device(
a_fsas=G,
b_fsas=word_fsas_with_self_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True,
ret_arc_maps=False,
)
rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
use_double_scores=True,
log_semiring=False,
)
# Now RNN-LM
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("sos_id")
eos_id = sp.piece_to_id("eos_id")
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64)
y_tokens = y_tokens.to(torch.int64)
sentence_lengths = sentence_lengths.to(torch.int64)
rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths)
assert rnn_lm_nll.ndim == 2
assert rnn_lm_nll.shape[0] == len(token_list)
rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)
ans: Dict[str, List[List[int]]] = {}
for n_scale in ngram_lm_scale_list:
for rnn_scale in rnn_lm_scale_list:
key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
tot_scores = (
am_scores.values
+ n_scale * ngram_lm_scores
+ rnn_scale * rnn_lm_scores
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
ans[key] = hyps
return ans

View File

@ -112,6 +112,7 @@ from beam_search import (
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
fast_beam_search_with_nbest_rescoring,
fast_beam_search_with_nbest_rnn_rescoring,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -125,8 +126,10 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
@ -342,6 +345,62 @@ def get_parser():
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is rnn-lm.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is rnn-lm.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is rnn-lm.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=True,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
add_model_arguments(parser)
return parser
@ -355,6 +414,7 @@ def decode_one_batch(
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
G: Optional[k2.Fsa] = None,
rnn_lm_model: torch.nn.Module = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -526,6 +586,30 @@ def decode_one_batch(
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
elif params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring":
ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
ngram_lm_scale_list += [0.01, 0.02, 0.05]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
hyp_tokens = fast_beam_search_with_nbest_rnn_rescoring(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
ngram_lm_scale_list=ngram_lm_scale_list,
num_paths=params.num_paths,
G=G,
sp=sp,
word_table=word_table,
rnn_lm_model=rnn_lm_model,
rnn_lm_scale_list=ngram_lm_scale_list,
use_double_scores=True,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
else:
batch_size = encoder_out.size(0)
@ -571,7 +655,10 @@ def decode_one_batch(
f"temperature_{params.temperature}"
): hyps
}
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
elif params.decoding_method in [
"fast_beam_search_with_nbest_rescoring",
"fast_beam_search_with_nbest_rnn_rescoring",
]:
prefix = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
@ -612,6 +699,7 @@ def decode_dataset(
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
G: Optional[k2.Fsa] = None,
rnn_lm_model: torch.nn.Module = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -666,6 +754,7 @@ def decode_dataset(
decoding_graph=decoding_graph,
batch=batch,
G=G,
rnn_lm_model=rnn_lm_model,
)
for name, hyps in hyps_dict.items():
@ -816,6 +905,7 @@ def main():
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"fast_beam_search_with_nbest_rescoring",
"fast_beam_search_with_nbest_rnn_rescoring",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -919,7 +1009,10 @@ def main():
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
elif params.decoding_method in [
"fast_beam_search_with_nbest_rescoring",
"fast_beam_search_with_nbest_rnn_rescoring",
]:
logging.info(f"Loading word symbol table from {params.words_txt}")
word_table = k2.SymbolTable.from_file(params.words_txt)
@ -932,14 +1025,43 @@ def main():
params.vocab_size - 1, device=device
)
logging.info(f"G properties_str: {G.properties_str}")
rnn_lm_model = None
if (
params.decoding_method
== "fast_beam_search_with_nbest_rnn_rescoring"
):
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
if params.rnn_lm_avg == 1:
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
else:
rnn_lm_model = load_averaged_model(
params.rnn_lm_exp_dir,
rnn_lm_model,
params.rnn_lm_epoch,
params.rnn_lm_avg,
device,
)
rnn_lm_model.eval()
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
rnn_lm_model = None
else:
decoding_graph = None
word_table = None
rnn_lm_model = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -965,6 +1087,7 @@ def main():
word_table=word_table,
decoding_graph=decoding_graph,
G=G,
rnn_lm_model=rnn_lm_model,
)
save_results(

View File

@ -1006,6 +1006,8 @@ def rescore_with_rnn_lm(
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
rnn_lm_model:
A rnn-lm model used for LM rescoring
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.