[WIP] Rnn-T LM nbest rescoring (#471)

This commit is contained in:
ezerhouni 2022-07-15 04:32:54 +02:00 committed by GitHub
parent c17233eca7
commit ffca1ae7fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 358 additions and 6 deletions

View File

@ -19,6 +19,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional
import k2
import sentencepiece as spm
import torch
from model import Transducer
@ -34,6 +35,7 @@ def fast_beam_search_one_best(
beam: float,
max_states: int,
max_contexts: int,
temperature: float = 1.0,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
@ -56,6 +58,8 @@ def fast_beam_search_one_best(
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
@ -67,6 +71,7 @@ def fast_beam_search_one_best(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
best_path = one_best_decoding(lattice)
@ -85,6 +90,7 @@ def fast_beam_search_nbest_LG(
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
@ -120,6 +126,8 @@ def fast_beam_search_nbest_LG(
use_double_scores:
True to use double precision for computation. False to use
single precision.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
@ -131,6 +139,7 @@ def fast_beam_search_nbest_LG(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
nbest = Nbest.from_lattice(
@ -201,6 +210,7 @@ def fast_beam_search_nbest(
num_paths: int,
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
@ -236,6 +246,8 @@ def fast_beam_search_nbest(
use_double_scores:
True to use double precision for computation. False to use
single precision.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
@ -247,6 +259,7 @@ def fast_beam_search_nbest(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
nbest = Nbest.from_lattice(
@ -282,6 +295,7 @@ def fast_beam_search_nbest_oracle(
ref_texts: List[List[int]],
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
@ -321,7 +335,8 @@ def fast_beam_search_nbest_oracle(
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
@ -333,6 +348,7 @@ def fast_beam_search_nbest_oracle(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
nbest = Nbest.from_lattice(
@ -373,6 +389,7 @@ def fast_beam_search(
beam: float,
max_states: int,
max_contexts: int,
temperature: float = 1.0,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
@ -392,6 +409,8 @@ def fast_beam_search(
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
temperature:
Softmax temperature.
Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned
@ -440,7 +459,7 @@ def fast_beam_search(
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
log_probs = (logits / temperature).log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
@ -783,6 +802,7 @@ def modified_beam_search(
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@ -796,6 +816,8 @@ def modified_beam_search(
encoder_out before padding.
beam:
Number of active paths during the beam search.
temperature:
Softmax temperature.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
@ -879,7 +901,9 @@ def modified_beam_search(
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs = (logits / temperature).log_softmax(
dim=-1
) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
@ -1043,6 +1067,7 @@ def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@ -1056,6 +1081,8 @@ def beam_search(
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
temperature:
Softmax temperature.
Returns:
Return the decoded result.
"""
@ -1132,7 +1159,7 @@ def beam_search(
)
# TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1)
log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
@ -1171,3 +1198,155 @@ def beam_search(
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
def fast_beam_search_with_nbest_rescoring(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
ngram_lm_scale_list: List[float],
num_paths: int,
G: k2.Fsa,
sp: spm.SentencePieceProcessor,
word_table: k2.SymbolTable,
oov_word: str = "<UNK>",
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
ngram_lm_scale_list:
A list of floats representing LM score scales.
num_paths:
Number of paths to extract from the decoded lattice.
G:
An FsaVec containing only a single FSA. It is an n-gram LM.
sp:
The BPE model.
word_table:
The word symbol table.
oov_word:
OOV words are replaced with this word.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
temperature:
Softmax temperature.
Returns:
Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
ngram LM scale value used during decoding, i.e., 0.1.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
# at this point, nbest.fsa.scores are all zeros.
nbest = nbest.intersect(lattice)
# Now nbest.fsa.scores contains acoustic scores
am_scores = nbest.tot_scores()
# Now we need to compute the LM scores of each path.
# (1) Get the token IDs of each Path. We assume the decoding_graph
# is an acceptor, i.e., lattice is also an acceptor
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc]
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
tokens = tokens.remove_values_leq(0) # remove -1 and 0
token_list: List[List[int]] = tokens.tolist()
word_list: List[List[str]] = sp.decode(token_list)
assert isinstance(oov_word, str), oov_word
assert oov_word in word_table, oov_word
oov_word_id = word_table[oov_word]
word_ids_list: List[List[int]] = []
for words in word_list:
this_word_ids = []
for w in words.split():
if w in word_table:
this_word_ids.append(word_table[w])
else:
this_word_ids.append(oov_word_id)
word_ids_list.append(this_word_ids)
word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
num_unique_paths = len(word_ids_list)
b_to_a_map = torch.zeros(
num_unique_paths,
dtype=torch.int32,
device=lattice.device,
)
rescored_word_fsas = k2.intersect_device(
a_fsas=G,
b_fsas=word_fsas_with_self_loops,
b_to_a_map=b_to_a_map,
sorted_match_a=True,
ret_arc_maps=False,
)
rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
use_double_scores=True,
log_semiring=False,
)
ans: Dict[str, List[List[int]]] = {}
for s in ngram_lm_scale_list:
key = f"ngram_lm_scale_{s}"
tot_scores = am_scores.values + s * ngram_lm_scores
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
ans[key] = hyps
return ans

View File

@ -111,6 +111,7 @@ from beam_search import (
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
fast_beam_search_with_nbest_rescoring,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -312,6 +313,35 @@ def get_parser():
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="""Softmax temperature.
The output of the model is (logits / temperature).log_softmax().
""",
)
parser.add_argument(
"--lm-dir",
type=Path,
default=Path("./data/lm"),
help="""Used only when --decoding-method is
fast_beam_search_with_nbest_rescoring.
It should contain either G_4_gram.pt or G_4_gram.fst.txt
""",
)
parser.add_argument(
"--words-txt",
type=Path,
default=Path("./data/lang_bpe_500/words.txt"),
help="""Used only when --decoding-method is
fast_beam_search_with_nbest_rescoring.
It is the word table.
""",
)
add_model_arguments(parser)
return parser
@ -324,6 +354,7 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -352,6 +383,11 @@ def decode_one_batch(
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G:
Optional. Used only when decoding method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
or fast_beam_search_with_nbest_rescoring.
It an FsaVec containing an acceptor.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -397,6 +433,7 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
temperature=params.temperature,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -411,6 +448,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
@ -425,6 +463,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -440,6 +479,7 @@ def decode_one_batch(
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -460,9 +500,32 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
temperature=params.temperature,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
ngram_lm_scale_list += [0.01, 0.02, 0.05]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
hyp_tokens = fast_beam_search_with_nbest_rescoring(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
ngram_lm_scale_list=ngram_lm_scale_list,
num_paths=params.num_paths,
G=G,
sp=sp,
word_table=word_table,
use_double_scores=True,
nbest_scale=params.nbest_scale,
temperature=params.temperature,
)
else:
batch_size = encoder_out.size(0)
@ -496,6 +559,7 @@ def decode_one_batch(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
f"temperature_{params.temperature}"
): hyps
}
elif params.decoding_method == "fast_beam_search":
@ -504,8 +568,23 @@ def decode_one_batch(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
f"temperature_{params.temperature}"
): hyps
}
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
prefix = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}_"
f"num_paths_{params.num_paths}_"
f"nbest_scale_{params.nbest_scale}_"
f"temperature_{params.temperature}_"
)
ans: Dict[str, List[List[str]]] = {}
for key, hyp in hyp_tokens.items():
t: List[str] = sp.decode(hyp)
ans[prefix + key] = [s.split() for s in t]
return ans
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
@ -515,10 +594,14 @@ def decode_one_batch(
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
return {
(
f"beam_size_{params.beam_size}_"
f"temperature_{params.temperature}"
): hyps
}
def decode_dataset(
@ -528,6 +611,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -546,6 +630,11 @@ def decode_dataset(
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G:
Optional. Used only when decoding method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
or fast_beam_search_with_nbest_rescoring.
It's an FsaVec containing an acceptor.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -576,6 +665,7 @@ def decode_dataset(
word_table=word_table,
decoding_graph=decoding_graph,
batch=batch,
G=G,
)
for name, hyps in hyps_dict.items():
@ -642,6 +732,71 @@ def save_results(
logging.info(s)
def load_ngram_LM(
lm_dir: Path, word_table: k2.SymbolTable, device: torch.device
) -> k2.Fsa:
"""Read a ngram model from the given directory.
Args:
lm_dir:
It should contain either G_4_gram.pt or G_4_gram.fst.txt
word_table:
The word table mapping words to IDs and vice versa.
device:
The resulting FSA will be moved to this device.
Returns:
Return an FsaVec containing a single acceptor.
"""
lm_dir = Path(lm_dir)
assert lm_dir.is_dir(), f"{lm_dir} does not exist"
pt_file = lm_dir / "G_4_gram.pt"
if pt_file.is_file():
logging.info(f"Loading pre-compiled {pt_file}")
d = torch.load(pt_file, map_location=device)
G = k2.Fsa.from_dict(d)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
return G
txt_file = lm_dir / "G_4_gram.fst.txt"
assert txt_file.is_file(), f"{txt_file} does not exist"
logging.info(f"Loading {txt_file}")
logging.warning("It may take 8 minutes (Will be cached for later use).")
with open(txt_file) as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# Now G is an acceptor
first_word_disambig_id = word_table["#0"]
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set G.properties to None
G.__dict__["_properties"] = None
G = k2.Fsa.from_fsas([G]).to(device)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
logging.info(f"Saving to {pt_file} for later use")
torch.save(G.as_dict(), pt_file)
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
return G
@torch.no_grad()
def main():
parser = get_parser()
@ -660,6 +815,7 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"fast_beam_search_with_nbest_rescoring",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -676,6 +832,7 @@ def main():
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
params.suffix += f"-temperature-{params.temperature}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
@ -685,9 +842,11 @@ def main():
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-temperature-{params.temperature}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-temperature-{params.temperature}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -760,6 +919,19 @@ def main():
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
logging.info(f"Loading word symbol table from {params.words_txt}")
word_table = k2.SymbolTable.from_file(params.words_txt)
G = load_ngram_LM(
lm_dir=params.lm_dir,
word_table=word_table,
device=device,
)
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
logging.info(f"G properties_str: {G.properties_str}")
else:
word_table = None
decoding_graph = k2.trivial_graph(
@ -792,6 +964,7 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
G=G,
)
save_results(