[WIP] Rnn-T LM nbest rescoring (#471)

ezerhouni 2022-07-15 04:32:54 +02:00 committed by GitHub
parent c17233eca7
commit ffca1ae7fb
2 changed files with 358 additions and 6 deletions

View File

@@ -19,6 +19,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional

import k2
+import sentencepiece as spm
import torch

from model import Transducer
@@ -34,6 +35,7 @@ def fast_beam_search_one_best(
    beam: float,
    max_states: int,
    max_contexts: int,
+    temperature: float = 1.0,
) -> List[List[int]]:
    """It limits the maximum number of symbols per frame to 1.
@@ -56,6 +58,8 @@ def fast_beam_search_one_best(
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
+      temperature:
+        Softmax temperature.
    Returns:
      Return the decoded result.
    """
@@ -67,6 +71,7 @@ def fast_beam_search_one_best(
        beam=beam,
        max_states=max_states,
        max_contexts=max_contexts,
+        temperature=temperature,
    )

    best_path = one_best_decoding(lattice)
@@ -85,6 +90,7 @@ def fast_beam_search_nbest_LG(
    num_paths: int,
    nbest_scale: float = 0.5,
    use_double_scores: bool = True,
+    temperature: float = 1.0,
) -> List[List[int]]:
    """It limits the maximum number of symbols per frame to 1.
@@ -120,6 +126,8 @@ def fast_beam_search_nbest_LG(
      use_double_scores:
        True to use double precision for computation. False to use
        single precision.
+      temperature:
+        Softmax temperature.
    Returns:
      Return the decoded result.
    """
@@ -131,6 +139,7 @@ def fast_beam_search_nbest_LG(
        beam=beam,
        max_states=max_states,
        max_contexts=max_contexts,
+        temperature=temperature,
    )

    nbest = Nbest.from_lattice(
@@ -201,6 +210,7 @@ def fast_beam_search_nbest(
    num_paths: int,
    nbest_scale: float = 0.5,
    use_double_scores: bool = True,
+    temperature: float = 1.0,
) -> List[List[int]]:
    """It limits the maximum number of symbols per frame to 1.
@@ -236,6 +246,8 @@ def fast_beam_search_nbest(
      use_double_scores:
        True to use double precision for computation. False to use
        single precision.
+      temperature:
+        Softmax temperature.
    Returns:
      Return the decoded result.
    """
@@ -247,6 +259,7 @@ def fast_beam_search_nbest(
        beam=beam,
        max_states=max_states,
        max_contexts=max_contexts,
+        temperature=temperature,
    )

    nbest = Nbest.from_lattice(
@@ -282,6 +295,7 @@ def fast_beam_search_nbest_oracle(
    ref_texts: List[List[int]],
    use_double_scores: bool = True,
    nbest_scale: float = 0.5,
+    temperature: float = 1.0,
) -> List[List[int]]:
    """It limits the maximum number of symbols per frame to 1.
@@ -321,7 +335,8 @@ def fast_beam_search_nbest_oracle(
      nbest_scale:
        It's the scale applied to the lattice.scores. A smaller value
        yields more unique paths.
+      temperature:
+        Softmax temperature.
    Returns:
      Return the decoded result.
    """
@@ -333,6 +348,7 @@ def fast_beam_search_nbest_oracle(
        beam=beam,
        max_states=max_states,
        max_contexts=max_contexts,
+        temperature=temperature,
    )

    nbest = Nbest.from_lattice(
@@ -373,6 +389,7 @@ def fast_beam_search(
    beam: float,
    max_states: int,
    max_contexts: int,
+    temperature: float = 1.0,
) -> k2.Fsa:
    """It limits the maximum number of symbols per frame to 1.
@@ -392,6 +409,8 @@ def fast_beam_search(
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
+      temperature:
+        Softmax temperature.
    Returns:
      Return an FsaVec with axes [utt][state][arc] containing the decoded
      lattice. Note: When the input graph is a TrivialGraph, the returned
@@ -440,7 +459,7 @@ def fast_beam_search(
            project_input=False,
        )
        logits = logits.squeeze(1).squeeze(1)
-        log_probs = logits.log_softmax(dim=-1)
+        log_probs = (logits / temperature).log_softmax(dim=-1)
        decoding_streams.advance(log_probs)
    decoding_streams.terminate_and_flush_to_streams()
    lattice = decoding_streams.format_output(encoder_out_lens.tolist())
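Note on the new `temperature` argument used throughout this file: dividing the logits by a temperature before `log_softmax` leaves the ranking of the tokens unchanged, but flattens the distribution for temperature > 1 and sharpens it for temperature < 1. A minimal self-contained sketch (the logit values are made up for illustration):

    import torch

    logits = torch.tensor([2.0, 1.0, 0.5])
    for temperature in (0.5, 1.0, 2.0):
        # Higher temperature -> higher-entropy distribution;
        # the argmax token stays the same.
        probs = (logits / temperature).log_softmax(dim=-1).exp()
        print(temperature, probs)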
@@ -783,6 +802,7 @@ def modified_beam_search(
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: int = 4,
+    temperature: float = 1.0,
) -> List[List[int]]:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -796,6 +816,8 @@ def modified_beam_search(
        encoder_out before padding.
      beam:
        Number of active paths during the beam search.
+      temperature:
+        Softmax temperature.
    Returns:
      Return a list-of-list of token IDs. ans[i] is the decoding results
      for the i-th utterance.
@@ -879,7 +901,9 @@ def modified_beam_search(
        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

-        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+        log_probs = (logits / temperature).log_softmax(
+            dim=-1
+        )  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)
@@ -1043,6 +1067,7 @@ def beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
+    temperature: float = 1.0,
) -> List[int]:
    """
    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@@ -1056,6 +1081,8 @@ def beam_search(
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
+      temperature:
+        Softmax temperature.
    Returns:
      Return the decoded result.
    """
@@ -1132,7 +1159,7 @@ def beam_search(
            )

            # TODO(fangjun): Scale the blank posterior
-            log_prob = logits.log_softmax(dim=-1)
+            log_prob = (logits / temperature).log_softmax(dim=-1)
            # log_prob is (1, 1, 1, vocab_size)
            log_prob = log_prob.squeeze()
            # Now log_prob is (vocab_size,)
@@ -1171,3 +1198,155 @@ def beam_search(
    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
    return ys
+
+
+def fast_beam_search_with_nbest_rescoring(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+    ngram_lm_scale_list: List[float],
+    num_paths: int,
+    G: k2.Fsa,
+    sp: spm.SentencePieceProcessor,
+    word_table: k2.SymbolTable,
+    oov_word: str = "<UNK>",
+    use_double_scores: bool = True,
+    nbest_scale: float = 0.5,
+    temperature: float = 1.0,
+) -> Dict[str, List[List[int]]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
ngram_lm_scale_list:
A list of floats representing LM score scales.
num_paths:
Number of paths to extract from the decoded lattice.
G:
An FsaVec containing only a single FSA. It is an n-gram LM.
sp:
The BPE model.
word_table:
The word symbol table.
oov_word:
OOV words are replaced with this word.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
temperature:
Softmax temperature.
Returns:
Return the decoded result in a dict, where the key has the form
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
ngram LM scale value used during decoding, i.e., 0.1.
"""
+    lattice = fast_beam_search(
+        model=model,
+        decoding_graph=decoding_graph,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        max_states=max_states,
+        max_contexts=max_contexts,
+        temperature=temperature,
+    )
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        nbest_scale=nbest_scale,
+    )
+    # At this point, nbest.fsa.scores are all zeros.
+
+    nbest = nbest.intersect(lattice)
+    # Now nbest.fsa.scores contains acoustic scores
+
+    am_scores = nbest.tot_scores()
+
+    # Now we need to compute the LM scores of each path.
+    # (1) Get the token IDs of each path. We assume the decoding_graph
+    # is an acceptor, i.e., the lattice is also an acceptor.
+    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)  # [path][arc]
+    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
+    tokens = tokens.remove_values_leq(0)  # remove -1 and 0
+    token_list: List[List[int]] = tokens.tolist()
+    word_list: List[str] = sp.decode(token_list)
+
+    assert isinstance(oov_word, str), oov_word
+    assert oov_word in word_table, oov_word
+    oov_word_id = word_table[oov_word]
+
+    word_ids_list: List[List[int]] = []
+    for words in word_list:
+        this_word_ids = []
+        for w in words.split():
+            if w in word_table:
+                this_word_ids.append(word_table[w])
+            else:
+                this_word_ids.append(oov_word_id)
+        word_ids_list.append(this_word_ids)
+
+    word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
+    word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
+
+    num_unique_paths = len(word_ids_list)
+    b_to_a_map = torch.zeros(
+        num_unique_paths,
+        dtype=torch.int32,
+        device=lattice.device,
+    )
+
+    rescored_word_fsas = k2.intersect_device(
+        a_fsas=G,
+        b_fsas=word_fsas_with_self_loops,
+        b_to_a_map=b_to_a_map,
+        sorted_match_a=True,
+        ret_arc_maps=False,
+    )
+
+    rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas)
+    rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
+    ngram_lm_scores = rescored_word_fsas.get_tot_scores(
+        use_double_scores=True,
+        log_semiring=False,
+    )
+
+    ans: Dict[str, List[List[int]]] = {}
+    for s in ngram_lm_scale_list:
+        key = f"ngram_lm_scale_{s}"
+        tot_scores = am_scores.values + s * ngram_lm_scores
+        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        max_indexes = ragged_tot_scores.argmax()
+        best_path = k2.index_fsa(nbest.fsa, max_indexes)
+        hyps = get_texts(best_path)
+        ans[key] = hyps
+
+    return ans
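The loop over `ngram_lm_scale_list` is the actual rescoring step: for each scale `s`, the total score of every n-best path is `am_score + s * ngram_lm_score`, and the highest-scoring path per utterance is kept. A small self-contained sketch of that combination, with made-up scores for one utterance and three candidate paths:

    import torch

    am_scores = torch.tensor([-10.2, -10.5, -11.0])  # acoustic score per path
    lm_scores = torch.tensor([-4.0, -3.0, -5.0])  # n-gram LM score per path

    for s in (0.1, 0.5, 1.0):
        tot_scores = am_scores + s * lm_scores
        print(f"ngram_lm_scale_{s}: best path = {tot_scores.argmax().item()}")

With a small scale the acoustic score dominates and path 0 wins; from s = 0.5 onward the LM score pulls the decision to path 1, which is why the caller sweeps a whole list of scales and reports one result per key.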

View File

@ -111,6 +111,7 @@ from beam_search import (
fast_beam_search_nbest_LG, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle, fast_beam_search_nbest_oracle,
fast_beam_search_one_best, fast_beam_search_one_best,
fast_beam_search_with_nbest_rescoring,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@@ -312,6 +313,35 @@ def get_parser():
        help="left context can be seen during decoding (in frames after subsampling)",
    )

+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.0,
+        help="""Softmax temperature.
+        The output of the model is (logits / temperature).log_softmax().
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=Path,
+        default=Path("./data/lm"),
+        help="""Used only when --decoding-method is
+        fast_beam_search_with_nbest_rescoring.
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        """,
+    )
+
+    parser.add_argument(
+        "--words-txt",
+        type=Path,
+        default=Path("./data/lang_bpe_500/words.txt"),
+        help="""Used only when --decoding-method is
+        fast_beam_search_with_nbest_rescoring.
+        It is the word table.
+        """,
+    )
+
    add_model_arguments(parser)

    return parser
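For reference, a decoding run with the new options might be invoked roughly as follows. This is a hypothetical sketch: the script name and any checkpoint-related flags depend on the recipe, and only --decoding-method, --temperature, --lm-dir, and --words-txt are defined above.

    ./decode.py \
        --decoding-method fast_beam_search_with_nbest_rescoring \
        --temperature 1.5 \
        --lm-dir ./data/lm \
        --words-txt ./data/lang_bpe_500/words.txt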
@@ -324,6 +354,7 @@ def decode_one_batch(
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
+    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -352,6 +383,11 @@ def decode_one_batch(
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+      G:
+        Optional. Used only when decoding method is fast_beam_search,
+        fast_beam_search_nbest, fast_beam_search_nbest_oracle,
+        or fast_beam_search_with_nbest_rescoring.
+        It's an FsaVec containing an acceptor.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
@@ -397,6 +433,7 @@ def decode_one_batch(
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
+            temperature=params.temperature,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@@ -411,6 +448,7 @@ def decode_one_batch(
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
        )
        for hyp in hyp_tokens:
            hyps.append([word_table[i] for i in hyp])
@@ -425,6 +463,7 @@ def decode_one_batch(
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@@ -440,6 +479,7 @@ def decode_one_batch(
            num_paths=params.num_paths,
            ref_texts=sp.encode(supervisions["text"]),
            nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@@ -460,9 +500,32 @@ def decode_one_batch(
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
+            temperature=params.temperature,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
+    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+        ngram_lm_scale_list = [-0.5, -0.2, -0.1, -0.05, -0.02, 0]
+        ngram_lm_scale_list += [0.01, 0.02, 0.05]
+        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.8]
+        ngram_lm_scale_list += [1.0, 1.5, 2.5, 3]
+        hyp_tokens = fast_beam_search_with_nbest_rescoring(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_states=params.max_states,
+            max_contexts=params.max_contexts,
+            ngram_lm_scale_list=ngram_lm_scale_list,
+            num_paths=params.num_paths,
+            G=G,
+            sp=sp,
+            word_table=word_table,
+            use_double_scores=True,
+            nbest_scale=params.nbest_scale,
+            temperature=params.temperature,
+        )
    else:
        batch_size = encoder_out.size(0)
@@ -496,6 +559,7 @@ def decode_one_batch(
                f"beam_{params.beam}_"
                f"max_contexts_{params.max_contexts}_"
                f"max_states_{params.max_states}"
+                f"_temperature_{params.temperature}"
            ): hyps
        }
    elif params.decoding_method == "fast_beam_search":
@@ -504,8 +568,23 @@ def decode_one_batch(
                f"beam_{params.beam}_"
                f"max_contexts_{params.max_contexts}_"
                f"max_states_{params.max_states}"
+                f"_temperature_{params.temperature}"
            ): hyps
        }
+    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+        prefix = (
+            f"beam_{params.beam}_"
+            f"max_contexts_{params.max_contexts}_"
+            f"max_states_{params.max_states}_"
+            f"num_paths_{params.num_paths}_"
+            f"nbest_scale_{params.nbest_scale}_"
+            f"temperature_{params.temperature}_"
+        )
+        ans: Dict[str, List[List[str]]] = {}
+        for key, hyp in hyp_tokens.items():
+            t: List[str] = sp.decode(hyp)
+            ans[prefix + key] = [s.split() for s in t]
+        return ans
    elif "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
@@ -515,10 +594,14 @@ def decode_one_batch(
        key += f"nbest_scale_{params.nbest_scale}"
        if "LG" in params.decoding_method:
            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
        return {key: hyps}
    else:
-        return {f"beam_size_{params.beam_size}": hyps}
+        return {
+            (
+                f"beam_size_{params.beam_size}_"
+                f"temperature_{params.temperature}"
+            ): hyps
+        }
def decode_dataset(
@@ -528,6 +611,7 @@ def decode_dataset(
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
+    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.
"""Decode dataset. """Decode dataset.
@ -546,6 +630,11 @@ def decode_dataset(
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
G:
Optional. Used only when decoding method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
or fast_beam_search_with_nbest_rescoring.
It's an FsaVec containing an acceptor.
Returns: Returns:
Return a dict, whose key may be "greedy_search" if greedy search Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used. is used, or it may be "beam_7" if beam size of 7 is used.
@@ -576,6 +665,7 @@ def decode_dataset(
            word_table=word_table,
            decoding_graph=decoding_graph,
            batch=batch,
+            G=G,
        )

        for name, hyps in hyps_dict.items():
@@ -642,6 +732,71 @@ def save_results(
        logging.info(s)
+
+
+def load_ngram_LM(
+    lm_dir: Path, word_table: k2.SymbolTable, device: torch.device
+) -> k2.Fsa:
+    """Read an n-gram model from the given directory.
+
+    Args:
+      lm_dir:
+        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+      word_table:
+        The word table mapping words to IDs and vice versa.
+      device:
+        The resulting FSA will be moved to this device.
+    Returns:
+      Return an FsaVec containing a single acceptor.
+    """
+    lm_dir = Path(lm_dir)
+    assert lm_dir.is_dir(), f"{lm_dir} does not exist"
+
+    pt_file = lm_dir / "G_4_gram.pt"
+
+    if pt_file.is_file():
+        logging.info(f"Loading pre-compiled {pt_file}")
+        d = torch.load(pt_file, map_location=device)
+        G = k2.Fsa.from_dict(d)
+        G = k2.add_epsilon_self_loops(G)
+        G = k2.arc_sort(G)
+        return G
+
+    txt_file = lm_dir / "G_4_gram.fst.txt"
+    assert txt_file.is_file(), f"{txt_file} does not exist"
+    logging.info(f"Loading {txt_file}")
+    logging.warning("It may take about 8 minutes (the result is cached for later use).")
+    with open(txt_file) as f:
+        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+
+    # G.aux_labels is not needed in later computations, so
+    # remove it here.
+    del G.aux_labels
+    # Now G is an acceptor
+
+    first_word_disambig_id = word_table["#0"]
+    # CAUTION: The following line is crucial.
+    # Arcs entering the back-off state have label equal to #0.
+    # We have to change it to 0 here.
+    G.labels[G.labels >= first_word_disambig_id] = 0
+
+    # See https://github.com/k2-fsa/k2/issues/874
+    # for why we need to set G.properties to None
+    G.__dict__["_properties"] = None
+
+    G = k2.Fsa.from_fsas([G]).to(device)
+
+    # Save a dummy value so that it can be loaded in C++.
+    # See https://github.com/pytorch/pytorch/issues/67902
+    # for why we need to do this.
+    G.dummy = 1
+
+    logging.info(f"Saving to {pt_file} for later use")
+    torch.save(G.as_dict(), pt_file)
+
+    G = k2.add_epsilon_self_loops(G)
+    G = k2.arc_sort(G)
+    return G
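The `G.labels[...] = 0` line above is the subtle part: in a G exported from OpenFST, arcs entering the back-off state carry the disambiguation symbol `#0` (and possibly higher disambig IDs) rather than a real word, and these must become epsilon (label 0) before intersection in k2. A toy illustration of that masked assignment on a plain tensor (the ID values are made up):

    import torch

    first_word_disambig_id = 200004  # e.g. word_table["#0"]
    labels = torch.tensor([13, 200004, 42, 200005])
    # Map every disambiguation symbol (#0, #1, ...) to epsilon (0).
    labels[labels >= first_word_disambig_id] = 0
    print(labels)  # tensor([13, 0, 42, 0])

Because this mutates the labels in place, k2's cached FSA properties become stale, which is why `G.__dict__["_properties"]` is reset to None right after (see the linked k2 issue #874).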
@torch.no_grad()
def main():
    parser = get_parser()
@@ -660,6 +815,7 @@ def main():
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
+        "fast_beam_search_with_nbest_rescoring",
    )
    params.res_dir = params.exp_dir / params.decoding_method
@@ -676,6 +832,7 @@ def main():
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"-temperature-{params.temperature}"
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
@@ -685,9 +842,11 @@ def main():
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
+        params.suffix += f"-temperature-{params.temperature}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+        params.suffix += f"-temperature-{params.temperature}"
    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")
@@ -760,6 +919,19 @@ def main():
            torch.load(lg_filename, map_location=device)
        )
        decoding_graph.scores *= params.ngram_lm_scale
+    elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
+        logging.info(f"Loading word symbol table from {params.words_txt}")
+        word_table = k2.SymbolTable.from_file(params.words_txt)
+        G = load_ngram_LM(
+            lm_dir=params.lm_dir,
+            word_table=word_table,
+            device=device,
+        )
+        decoding_graph = k2.trivial_graph(
+            params.vocab_size - 1, device=device
+        )
+        logging.info(f"G properties_str: {G.properties_str}")
    else:
        word_table = None
        decoding_graph = k2.trivial_graph(
@@ -792,6 +964,7 @@ def main():
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
+            G=G,
        )

    save_results(