Add fast_beam_search_nbest
This commit is contained in:
parent
ce26495238
commit
2456307acb
@ -19,6 +19,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
@ -34,6 +35,7 @@ def fast_beam_search_one_best(
|
|||||||
beam: float,
|
beam: float,
|
||||||
max_states: int,
|
max_states: int,
|
||||||
max_contexts: int,
|
max_contexts: int,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
@ -56,6 +58,8 @@ def fast_beam_search_one_best(
|
|||||||
Max states per stream per frame.
|
Max states per stream per frame.
|
||||||
max_contexts:
|
max_contexts:
|
||||||
Max contexts pre stream per frame.
|
Max contexts pre stream per frame.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
Return the decoded result.
|
||||||
"""
|
"""
|
||||||
@ -67,6 +71,7 @@ def fast_beam_search_one_best(
|
|||||||
beam=beam,
|
beam=beam,
|
||||||
max_states=max_states,
|
max_states=max_states,
|
||||||
max_contexts=max_contexts,
|
max_contexts=max_contexts,
|
||||||
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
best_path = one_best_decoding(lattice)
|
best_path = one_best_decoding(lattice)
|
||||||
@ -85,6 +90,7 @@ def fast_beam_search_nbest_LG(
|
|||||||
num_paths: int,
|
num_paths: int,
|
||||||
nbest_scale: float = 0.5,
|
nbest_scale: float = 0.5,
|
||||||
use_double_scores: bool = True,
|
use_double_scores: bool = True,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
@ -131,6 +137,7 @@ def fast_beam_search_nbest_LG(
|
|||||||
beam=beam,
|
beam=beam,
|
||||||
max_states=max_states,
|
max_states=max_states,
|
||||||
max_contexts=max_contexts,
|
max_contexts=max_contexts,
|
||||||
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
nbest = Nbest.from_lattice(
|
nbest = Nbest.from_lattice(
|
||||||
@ -201,6 +208,7 @@ def fast_beam_search_nbest(
|
|||||||
num_paths: int,
|
num_paths: int,
|
||||||
nbest_scale: float = 0.5,
|
nbest_scale: float = 0.5,
|
||||||
use_double_scores: bool = True,
|
use_double_scores: bool = True,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
@ -247,6 +255,7 @@ def fast_beam_search_nbest(
|
|||||||
beam=beam,
|
beam=beam,
|
||||||
max_states=max_states,
|
max_states=max_states,
|
||||||
max_contexts=max_contexts,
|
max_contexts=max_contexts,
|
||||||
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
nbest = Nbest.from_lattice(
|
nbest = Nbest.from_lattice(
|
||||||
@ -282,6 +291,7 @@ def fast_beam_search_nbest_oracle(
|
|||||||
ref_texts: List[List[int]],
|
ref_texts: List[List[int]],
|
||||||
use_double_scores: bool = True,
|
use_double_scores: bool = True,
|
||||||
nbest_scale: float = 0.5,
|
nbest_scale: float = 0.5,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
@ -333,6 +343,7 @@ def fast_beam_search_nbest_oracle(
|
|||||||
beam=beam,
|
beam=beam,
|
||||||
max_states=max_states,
|
max_states=max_states,
|
||||||
max_contexts=max_contexts,
|
max_contexts=max_contexts,
|
||||||
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
nbest = Nbest.from_lattice(
|
nbest = Nbest.from_lattice(
|
||||||
@ -373,6 +384,7 @@ def fast_beam_search(
|
|||||||
beam: float,
|
beam: float,
|
||||||
max_states: int,
|
max_states: int,
|
||||||
max_contexts: int,
|
max_contexts: int,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> k2.Fsa:
|
) -> k2.Fsa:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
@ -440,7 +452,7 @@ def fast_beam_search(
|
|||||||
project_input=False,
|
project_input=False,
|
||||||
)
|
)
|
||||||
logits = logits.squeeze(1).squeeze(1)
|
logits = logits.squeeze(1).squeeze(1)
|
||||||
log_probs = logits.log_softmax(dim=-1)
|
log_probs = (logits / temperature).log_softmax(dim=-1)
|
||||||
decoding_streams.advance(log_probs)
|
decoding_streams.advance(log_probs)
|
||||||
decoding_streams.terminate_and_flush_to_streams()
|
decoding_streams.terminate_and_flush_to_streams()
|
||||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||||
@ -783,6 +795,7 @@ def modified_beam_search(
|
|||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
|
|
||||||
@ -879,7 +892,9 @@ def modified_beam_search(
|
|||||||
|
|
||||||
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
log_probs = (logits / temperature).log_softmax(
|
||||||
|
dim=-1
|
||||||
|
) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
log_probs.add_(ys_log_probs)
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
@ -1043,6 +1058,7 @@ def beam_search(
|
|||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||||
@ -1132,7 +1148,7 @@ def beam_search(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO(fangjun): Scale the blank posterior
|
# TODO(fangjun): Scale the blank posterior
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
log_prob = (logits / temperature).log_softmax(dim=-1)
|
||||||
# log_prob is (1, 1, 1, vocab_size)
|
# log_prob is (1, 1, 1, vocab_size)
|
||||||
log_prob = log_prob.squeeze()
|
log_prob = log_prob.squeeze()
|
||||||
# Now log_prob is (vocab_size,)
|
# Now log_prob is (vocab_size,)
|
||||||
@ -1171,3 +1187,155 @@ def beam_search(
|
|||||||
best_hyp = B.get_most_probable(length_norm=True)
|
best_hyp = B.get_most_probable(length_norm=True)
|
||||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||||
return ys
|
return ys
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_with_nbest_rescoring(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
ngram_lm_scale_list: List[float],
|
||||||
|
num_paths: int,
|
||||||
|
G: k2.Fsa,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
word_table: k2.SymbolTable,
|
||||||
|
oov_word: str = "<UNK>",
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
nbest_scale: float = 0.5,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
A lattice is first obtained using modified beam search, and then
|
||||||
|
the shortest path within the lattice is used as the final output.
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi.
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts pre stream per frame.
|
||||||
|
ngram_lm_scale_list:
|
||||||
|
A list of floats representing LM score scales.
|
||||||
|
num_paths:
|
||||||
|
Number of paths to extract from the decoded lattice.
|
||||||
|
G:
|
||||||
|
An FsaVec containing only a single FSA. It is an n-gram LM.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
oov_word:
|
||||||
|
OOV words are replaced with this word.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision for computation. False to use
|
||||||
|
single precision.
|
||||||
|
nbest_scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result in a dict, where the key has the form
|
||||||
|
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
||||||
|
ngram LM scale value used during decoding, i.e., 0.1.
|
||||||
|
"""
|
||||||
|
lattice = fast_beam_search(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=beam,
|
||||||
|
max_states=max_states,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
nbest = Nbest.from_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
nbest_scale=nbest_scale,
|
||||||
|
)
|
||||||
|
# at this point, nbest.fsa.scores are all zeros.
|
||||||
|
|
||||||
|
nbest = nbest.intersect(lattice)
|
||||||
|
# Now nbest.fsa.scores contains acoustic scores
|
||||||
|
|
||||||
|
am_scores = nbest.tot_scores()
|
||||||
|
|
||||||
|
# Now we need to compute the LM scores of each path.
|
||||||
|
# (1) Get the token IDs of each Path. We assume the decoding_graph
|
||||||
|
# is an acceptor, i.e., lattice is also an acceptor
|
||||||
|
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc]
|
||||||
|
|
||||||
|
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
|
||||||
|
tokens = tokens.remove_values_leq(0) # remove -1 and 0
|
||||||
|
|
||||||
|
token_list: List[List[int]] = tokens.tolist()
|
||||||
|
word_list: List[List[str]] = sp.decode(token_list)
|
||||||
|
|
||||||
|
assert isinstance(oov_word, str), oov_word
|
||||||
|
assert oov_word in word_table, oov_word
|
||||||
|
oov_word_id = word_table[oov_word]
|
||||||
|
|
||||||
|
word_ids_list: List[List[int]] = []
|
||||||
|
|
||||||
|
for words in word_list:
|
||||||
|
this_word_ids = []
|
||||||
|
for w in words.split():
|
||||||
|
if w in word_table:
|
||||||
|
this_word_ids.append(word_table[w])
|
||||||
|
else:
|
||||||
|
this_word_ids.append(oov_word_id)
|
||||||
|
word_ids_list.append(this_word_ids)
|
||||||
|
|
||||||
|
word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
|
||||||
|
word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
|
||||||
|
|
||||||
|
num_unique_paths = len(word_ids_list)
|
||||||
|
|
||||||
|
b_to_a_map = torch.zeros(
|
||||||
|
num_unique_paths,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=lattice.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
rescored_word_fsas = k2.intersect_device(
|
||||||
|
a_fsas=G,
|
||||||
|
b_fsas=word_fsas_with_self_loops,
|
||||||
|
b_to_a_map=b_to_a_map,
|
||||||
|
sorted_match_a=True,
|
||||||
|
ret_arc_maps=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
|
||||||
|
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
|
||||||
|
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
|
||||||
|
use_double_scores=True,
|
||||||
|
log_semiring=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
ans: Dict[str, List[List[int]]] = {}
|
||||||
|
for s in ngram_lm_scale_list:
|
||||||
|
key = f"ngram_lm_scale_{s}"
|
||||||
|
tot_scores = am_scores.values + s * ngram_lm_scores
|
||||||
|
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||||
|
max_indexes = ragged_tot_scores.argmax()
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
|
||||||
|
ans[key] = hyps
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user