Add fast_beam_search_nbest

This commit is contained in:
Erwan 2022-07-07 15:11:20 +02:00
parent ce26495238
commit 2456307acb

View File

@ -19,6 +19,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
import k2 import k2
import sentencepiece as spm
import torch import torch
from model import Transducer from model import Transducer
@ -34,6 +35,7 @@ def fast_beam_search_one_best(
beam: float, beam: float,
max_states: int, max_states: int,
max_contexts: int, max_contexts: int,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -56,6 +58,8 @@ def fast_beam_search_one_best(
Max states per stream per frame. Max states per stream per frame.
max_contexts: max_contexts:
Max contexts pre stream per frame. Max contexts pre stream per frame.
temperature:
Softmax temperature.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@ -67,6 +71,7 @@ def fast_beam_search_one_best(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
best_path = one_best_decoding(lattice) best_path = one_best_decoding(lattice)
@ -85,6 +90,7 @@ def fast_beam_search_nbest_LG(
num_paths: int, num_paths: int,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -131,6 +137,7 @@ def fast_beam_search_nbest_LG(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@ -201,6 +208,7 @@ def fast_beam_search_nbest(
num_paths: int, num_paths: int,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -247,6 +255,7 @@ def fast_beam_search_nbest(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@ -282,6 +291,7 @@ def fast_beam_search_nbest_oracle(
ref_texts: List[List[int]], ref_texts: List[List[int]],
use_double_scores: bool = True, use_double_scores: bool = True,
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -333,6 +343,7 @@ def fast_beam_search_nbest_oracle(
beam=beam, beam=beam,
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@ -373,6 +384,7 @@ def fast_beam_search(
beam: float, beam: float,
max_states: int, max_states: int,
max_contexts: int, max_contexts: int,
temperature: float = 1.0,
) -> k2.Fsa: ) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@ -440,7 +452,7 @@ def fast_beam_search(
project_input=False, project_input=False,
) )
logits = logits.squeeze(1).squeeze(1) logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) log_probs = (logits / temperature).log_softmax(dim=-1)
decoding_streams.advance(log_probs) decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams() decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist()) lattice = decoding_streams.format_output(encoder_out_lens.tolist())
@ -783,6 +795,7 @@ def modified_beam_search(
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -879,7 +892,9 @@ def modified_beam_search(
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs = (logits / temperature).log_softmax(
dim=-1
) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs) log_probs.add_(ys_log_probs)
@ -1043,6 +1058,7 @@ def beam_search(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
beam: int = 4, beam: int = 4,
temperature: float = 1.0,
) -> List[int]: ) -> List[int]:
""" """
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -1132,7 +1148,7 @@ def beam_search(
) )
# TODO(fangjun): Scale the blank posterior # TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1) log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size) # log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze() log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,) # Now log_prob is (vocab_size,)
@ -1171,3 +1187,155 @@ def beam_search(
best_hyp = B.get_most_probable(length_norm=True) best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys return ys
def fast_beam_search_with_nbest_rescoring(
    model: Transducer,
    decoding_graph: k2.Fsa,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
    ngram_lm_scale_list: List[float],
    num_paths: int,
    G: k2.Fsa,
    sp: spm.SentencePieceProcessor,
    word_table: k2.SymbolTable,
    oov_word: str = "<UNK>",
    use_double_scores: bool = True,
    nbest_scale: float = 0.5,
    temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
    """It limits the maximum number of symbols per frame to 1.

    A lattice is first obtained using fast beam search. An n-best list is
    then extracted from the lattice and every path in it is rescored with
    the n-gram LM `G`. For each scale in `ngram_lm_scale_list`, the path
    maximizing `am_score + scale * ngram_lm_score` is returned.

    Args:
      model:
        An instance of `Transducer`.
      decoding_graph:
        Decoding graph used for decoding, may be a TrivialGraph or a HLG.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      encoder_out_lens:
        A tensor of shape (N,) containing the number of frames in `encoder_out`
        before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      ngram_lm_scale_list:
        A list of floats representing LM score scales.
      num_paths:
        Number of paths to extract from the decoded lattice.
      G:
        An FsaVec containing only a single FSA. It is an n-gram LM.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      oov_word:
        OOV words are replaced with this word. It must be present in
        `word_table`.
      use_double_scores:
        True to use double precision for computation. False to use
        single precision. It is forwarded to `Nbest.from_lattice`; the
        n-gram LM scores are always computed in double precision.
      nbest_scale:
        It's the scale applied to the lattice.scores. A smaller value
        yields more unique paths.
      temperature:
        Softmax temperature.
    Returns:
      Return the decoded result in a dict, where the key has the form
      'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
      ngram LM scale value used during decoding, i.e., 0.1.
    """
    lattice = fast_beam_search(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=beam,
        max_states=max_states,
        max_contexts=max_contexts,
        temperature=temperature,
    )

    nbest = Nbest.from_lattice(
        lattice=lattice,
        num_paths=num_paths,
        use_double_scores=use_double_scores,
        nbest_scale=nbest_scale,
    )
    # at this point, nbest.fsa.scores are all zeros.

    nbest = nbest.intersect(lattice)
    # Now nbest.fsa.scores contains acoustic scores

    am_scores = nbest.tot_scores()

    # Now we need to compute the LM scores of each path.
    # (1) Get the token IDs of each Path. We assume the decoding_graph
    # is an acceptor, i.e., lattice is also an acceptor
    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)  # [path][arc]

    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
    tokens = tokens.remove_values_leq(0)  # remove -1 and 0

    # (2) Turn the token IDs of each path into text. Each entry of
    # `word_list` is a single string that is split on whitespace below.
    token_list: List[List[int]] = tokens.tolist()
    word_list: List[str] = sp.decode(token_list)

    assert isinstance(oov_word, str), oov_word
    assert oov_word in word_table, oov_word
    oov_word_id = word_table[oov_word]

    # (3) Map every word to its ID; words missing from `word_table`
    # are replaced with the ID of `oov_word`.
    word_ids_list: List[List[int]] = [
        [
            word_table[w] if w in word_table else oov_word_id
            for w in words.split()
        ]
        for words in word_list
    ]

    # (4) Build one linear FSA per path and intersect it with G.
    word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
    word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)

    num_unique_paths = len(word_ids_list)

    # All zeros: every word FSA is paired with index 0 of `G`, which is
    # documented above as containing only a single FSA.
    b_to_a_map = torch.zeros(
        num_unique_paths,
        dtype=torch.int32,
        device=lattice.device,
    )

    # NOTE(review): sorted_match_a=True presumably requires `G` to be
    # arc-sorted -- confirm against the caller that builds `G`.
    rescored_word_fsas = k2.intersect_device(
        a_fsas=G,
        b_fsas=word_fsas_with_self_loops,
        b_to_a_map=b_to_a_map,
        sorted_match_a=True,
        ret_arc_maps=False,
    )

    rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
    rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
    ngram_lm_scores = rescored_word_fsas.get_tot_scores(
        use_double_scores=True,
        log_semiring=False,
    )

    # (5) For every LM scale, pick the path with the highest combined score.
    ans: Dict[str, List[List[int]]] = {}
    for s in ngram_lm_scale_list:
        key = f"ngram_lm_scale_{s}"
        tot_scores = am_scores.values + s * ngram_lm_scores
        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
        max_indexes = ragged_tot_scores.argmax()
        best_path = k2.index_fsa(nbest.fsa, max_indexes)
        hyps = get_texts(best_path)

        ans[key] = hyps

    return ans