diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py
new file mode 100755
index 000000000..45c4b7f5f
--- /dev/null
+++ b/egs/librispeech/ASR/local/compile_lg.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates LG from
+
+    - L, the lexicon, built from lang_dir/L_disambig.pt
+
+      Caution: We use a lexicon that contains disambiguation symbols
+
+    - G, the LM, built from data/lm/G_3_gram.fst.txt
+
+The generated LG is saved in $lang_dir/LG.pt
+"""
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+from icefall.lexicon import Lexicon
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def compile_LG(lang_dir: str) -> k2.Fsa:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+
+    Return:
+      An FSA representing LG.
+    """
+    lexicon = Lexicon(lang_dir)
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+    if Path("data/lm/G_3_gram.pt").is_file():
+        logging.info("Loading pre-compiled G_3_gram")
+        d = torch.load("data/lm/G_3_gram.pt")
+        G = k2.Fsa.from_dict(d)
+    else:
+        logging.info("Loading G_3_gram.fst.txt")
+        with open("data/lm/G_3_gram.fst.txt") as f:
+            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
+
+    first_token_disambig_id = lexicon.token_table["#0"]
+    first_word_disambig_id = lexicon.word_table["#0"]
+
+    L = k2.arc_sort(L)
+    G = k2.arc_sort(G)
+
+    logging.info("Intersecting L and G")
+    LG = k2.compose(L, G)
+    logging.info(f"LG shape: {LG.shape}")
+
+    logging.info("Connecting LG")
+    LG = k2.connect(LG)
+    logging.info(f"LG shape after k2.connect: {LG.shape}")
+
+    logging.info(type(LG.aux_labels))
+    logging.info("Determinizing LG")
+
+    LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
+    logging.info(type(LG.aux_labels))
+
+    logging.info("Connecting LG after k2.determinize")
+    LG = k2.connect(LG)
+
+    logging.info("Removing disambiguation symbols on LG")
+
+    LG.labels[LG.labels >= first_token_disambig_id] = 0
+    # See https://github.com/k2-fsa/k2/issues/874
+    # for why we need to set LG.properties to None
+    LG.__dict__["_properties"] = None
+
+    assert isinstance(LG.aux_labels, k2.RaggedTensor)
+    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
+
+    LG = k2.remove_epsilon(LG)
+    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
+
+    LG = k2.connect(LG)
+    LG.aux_labels = LG.aux_labels.remove_values_eq(0)
+
+    logging.info("Arc sorting LG")
+    LG = k2.arc_sort(LG)
+
+    return LG
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    if (lang_dir / "LG.pt").is_file():
+        logging.info(f"{lang_dir}/LG.pt already exists - skipping")
+        return
+
+    logging.info(f"Processing {lang_dir}")
+
+    LG = compile_LG(lang_dir)
+    logging.info(f"Saving LG.pt to {lang_dir}")
+    torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
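The script above only writes LG.pt; nothing consumes it until decode.py further down in this patch. As a quick orientation, here is a minimal sketch (not part of the patch) of loading such a graph and scaling its LM scores, the same way decode.py does when --use-LG is given; the lang dir path and the 0.01 scale are illustrative values only:

    # Minimal sketch, not part of the patch. The path and scale are examples.
    import k2
    import torch

    LG = k2.Fsa.from_dict(
        torch.load("data/lang_bpe_500/LG.pt", map_location="cpu")
    )
    LG.scores *= 0.01  # same role as the --ngram-lm-scale option added in decode.py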
logging.info(f"{lang_dir}/LG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + LG = compile_LG(lang_dir) + logging.info(f"Saving LG.pt to {lang_dir}") + torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 1bbf7bbcf..6b61c6b57 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -242,3 +242,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then ./local/compile_hlg.py --lang-dir $lang_dir done fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index ef1f399c6..100aeaa6e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -22,7 +22,7 @@ import k2 import torch from model import Transducer -from icefall.decode import one_best_decoding +from icefall.decode import Nbest, one_best_decoding from icefall.utils import get_texts @@ -34,6 +34,7 @@ def fast_beam_search( beam: float, max_states: int, max_contexts: int, + use_max: bool = False, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -53,6 +54,9 @@ def fast_beam_search( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return the decoded result. """ @@ -104,9 +108,67 @@ def fast_beam_search( decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps + if use_max: + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + return hyps + else: + num_paths = 200 + use_double_scores = True + nbest_scale = 0.8 + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. 
+            # inv_lattice has word IDs as labels and token IDs as aux_labels
+            inv_lattice = k2.invert(lattice)
+            inv_lattice = k2.arc_sort(inv_lattice)
+        else:
+            inv_lattice = k2.arc_sort(lattice)
+
+        if inv_lattice.shape[0] == 1:
+            path_lattice = k2.intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=torch.zeros_like(path_to_utt_map),
+                sorted_match_a=True,
+            )
+        else:
+            path_lattice = k2.intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=path_to_utt_map,
+                sorted_match_a=True,
+            )
+
+        # path_lattice has word IDs as labels and token IDs as aux_labels
+        path_lattice = k2.top_sort(k2.connect(path_lattice))
+
+        tot_scores = path_lattice.get_tot_scores(
+            use_double_scores=use_double_scores, log_semiring=True
+        )
+
+        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        best_hyp_indexes = ragged_tot_scores.argmax()
+
+        best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
+        hyps = get_texts(best_path)
+        return hyps
 
 
 def greedy_search(
@@ -280,7 +342,7 @@ class HypothesisList(object):
     def data(self) -> Dict[str, Hypothesis]:
         return self._data
 
-    def add(self, hyp: Hypothesis) -> None:
+    def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
         """Add a Hypothesis to `self`.
 
         If `hyp` already exists in `self`, its probability is updated using
@@ -289,13 +351,20 @@
         Args:
           hyp:
             The hypothesis to be added.
+          use_max:
+            True to select the hypothesis with the larger log_prob in case
+            there already exists a hypothesis whose `ys` equals `hyp.ys`.
+            False to use log-add.
         """
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            if use_max:
+                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
+            else:
+                torch.logaddexp(
+                    old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+                )
         else:
             self._data[key] = hyp
 
@@ -403,6 +472,7 @@ def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+    use_max: bool = False,
 ) -> List[List[int]]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
 
@@ -413,6 +483,9 @@
         Output from the encoder. Its shape is (N, T, C).
       beam:
         Number of active paths during the beam search.
+      use_max:
+        True to use max operation to select the hypothesis with the largest
+        log_prob when there are duplicate hypotheses; False to use log-add.
     Returns:
       Return a list-of-list of token IDs. ans[i] is the decoding results
       for the i-th utterance.
@@ -432,7 +505,8 @@
         Hypothesis(
             ys=[blank_id] * context_size,
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-        )
+        ),
+        use_max=use_max,
     )
 
     for t in range(T):
@@ -517,6 +591,7 @@ def _deprecated_modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+    use_max: bool = False,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
 
@@ -532,6 +607,9 @@
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
       beam:
         Beam size.
+      use_max:
+        True to use max operation to select the hypothesis with the largest
+        log_prob when there are duplicate hypotheses; False to use log-add.
     Returns:
       Return the decoded result.
""" @@ -553,12 +631,13 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) + ), + use_max=use_max, ) for t in range(T): # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + current_encoder_out = encoder_out[:, t:t + 1, :].unsqueeze(2) # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) # fmt: on A = list(B) @@ -611,7 +690,7 @@ def _deprecated_modified_beam_search( new_ys.append(new_token) new_log_prob = topk_log_probs[i] new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B.add(new_hyp) + B.add(new_hyp, use_max=use_max) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks @@ -623,6 +702,7 @@ def beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + use_max: bool = False, ) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -636,6 +716,9 @@ def beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return the decoded result. """ @@ -661,7 +744,9 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add( + Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max + ) max_sym_per_utt = 20000 @@ -720,7 +805,10 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + B.add( + Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob), + use_max=use_max, + ) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -729,7 +817,10 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + A.add( + Hypothesis(ys=new_ys, log_prob=new_log_prob), + use_max=use_max, + ) # Check whether B contains more than "beam" elements more probable # than the most probable in A diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 0e3b0f197..5082ab71a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -53,6 +53,19 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) fast beam search using LG +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --use-LG True \ + --use-max False \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 8 \ + --max-contexts 8 \ + --max-states 64 """ @@ -81,10 +94,12 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -136,6 +151,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -167,6 +189,36 @@ def get_parser(): Used only when --decoding-method is fast_beam_search""", ) + parser.add_argument( + "--use-LG", 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 0e3b0f197..5082ab71a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -53,6 +53,19 @@ Usage:
     --beam 4 \
     --max-contexts 4 \
     --max-states 8
+
+(5) fast beam search using LG
+./pruned_transducer_stateless/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless/exp \
+    --use-LG True \
+    --use-max False \
+    --max-duration 1500 \
+    --decoding-method fast_beam_search \
+    --beam 8 \
+    --max-contexts 8 \
+    --max-states 64
 """
 
 
@@ -81,10 +94,12 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
 
@@ -136,6 +151,13 @@ def get_parser():
         help="Path to the BPE model",
     )
 
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -167,6 +189,36 @@
         Used only when --decoding-method is fast_beam_search""",
     )
 
+    parser.add_argument(
+        "--use-LG",
+        type=str2bool,
+        default=False,
+        help="""Whether to use an LG graph for FSA-based beam search.
+        Used only when --decoding_method is fast_beam_search. If set to True,
+        it assumes there is an LG.pt file in lang_dir.""",
+    )
+
+    parser.add_argument(
+        "--use-max",
+        type=str2bool,
+        default=False,
+        help="""If True, use the max operation to select the hypothesis that
+        has the max log_prob in case of duplicate hypotheses.
+        If False, use log-add.
+        Used only for beam_search, modified_beam_search, and fast_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding_method is fast_beam_search.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
@@ -206,6 +258,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -229,6 +282,8 @@
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      word_table:
+        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search.
@@ -260,9 +315,14 @@
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
+           use_max=params.use_max,
        )
-       for hyp in sp.decode(hyp_tokens):
-           hyps.append(hyp.split())
+       if params.use_LG:
+           for hyp in hyp_tokens:
+               hyps.append([word_table[i] for i in hyp])
+       else:
+           for hyp in sp.decode(hyp_tokens):
+               hyps.append(hyp.split())
    elif (
        params.decoding_method == "greedy_search"
        and params.max_sym_per_frame == 1
@@ -278,6 +338,7 @@
            model=model,
            encoder_out=encoder_out,
            beam=params.beam_size,
+           use_max=params.use_max,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
@@ -299,6 +360,7 @@
                model=model,
                encoder_out=encoder_out_i,
                beam=params.beam_size,
+               use_max=params.use_max,
            )
        else:
            raise ValueError(
@@ -325,6 +387,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
@@ -338,6 +401,8 @@
        The neural model.
      sp:
        The BPE model.
+     word_table:
+       The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
        only when --decoding_method is fast_beam_search.
@@ -368,8 +433,9 @@ def decode_dataset(
            params=params,
            model=model,
            sp=sp,
-           decoding_graph=decoding_graph,
            batch=batch,
+           word_table=word_table,
+           decoding_graph=decoding_graph,
        )
 
        for name, hyps in hyps_dict.items():
@@ -460,13 +526,16 @@ def main():
    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
    if "fast_beam_search" in params.decoding_method:
+       params.suffix += f"-use-LG-{params.use_LG}"
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
+       params.suffix += f"-use-max-{params.use_max}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
+       params.suffix += f"-use-max-{params.use_max}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -527,9 +596,21 @@
    model.device = device
 
    if params.decoding_method == "fast_beam_search":
-       decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+       if params.use_LG:
+           lexicon = Lexicon(params.lang_dir)
+           word_table = lexicon.word_table
+           decoding_graph = k2.Fsa.from_dict(
+               torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
+           )
+           decoding_graph.scores *= params.ngram_lm_scale
+       else:
+           word_table = None
+           decoding_graph = k2.trivial_graph(
+               params.vocab_size - 1, device=device
+           )
    else:
        decoding_graph = None
+       word_table = None
 
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
@@ -551,6 +632,7 @@
            params=params,
            model=model,
            sp=sp,
+           word_table=word_table,
            decoding_graph=decoding_graph,
        )
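A closing usage note on decode.py: with --use-LG True the hypotheses returned by fast_beam_search are word IDs (the output symbols of LG), so they are looked up in the lexicon's word table instead of being passed to sp.decode(). A toy sketch of that branch; the IDs and the plain dict standing in for k2.SymbolTable are invented for illustration:

    # Invented IDs and table; in decode.py the table is Lexicon(lang_dir).word_table.
    hyp_tokens = [[13, 42, 7]]  # one utterance, word IDs from LG-based decoding

    word_table = {13: "HELLO", 42: "WORLD", 7: "AGAIN"}

    hyps = [[word_table[i] for i in hyp] for hyp in hyp_tokens]
    print(hyps)  # [['HELLO', 'WORLD', 'AGAIN']]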