From 8bb2f01b118cd5b879aaa18536d03394ddb93609 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 30 Mar 2022 10:52:47 +0800 Subject: [PATCH] Add LG decoding --- egs/librispeech/ASR/local/compile_lg.py | 141 ++++++++++++++++++ .../beam_search.py | 121 +++++++++++++-- .../ASR/pruned_transducer_stateless/decode.py | 92 +++++++++++- 3 files changed, 334 insertions(+), 20 deletions(-) create mode 100644 egs/librispeech/ASR/local/compile_lg.py diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py new file mode 100644 index 000000000..c8ab31729 --- /dev/null +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input lang_dir and generates LG from + + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_3_gram.fst.txt + +The generated LG is saved in $lang_dir/LG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_LG(lang_dir: str) -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + + Return: + An FSA representing LG. 
+ """ + lexicon = Lexicon(lang_dir) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path("data/lm/G_3_gram.pt").is_file(): + logging.info("Loading pre-compiled G_3_gram") + d = torch.load("data/lm/G_3_gram.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info("Loading G_3_gram.fst.txt") + with open("data/lm/G_3_gram.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), "data/lm/G_3_gram.pt") + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + LG.labels[LG.labels >= first_token_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + return LG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "LG.pt").is_file(): + logging.info(f"{lang_dir}/LG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + LG = compile_LG(lang_dir) + logging.info(f"Saving LG.pt to {lang_dir}") + torch.save(LG.as_dict(), f"{lang_dir}/LG.pt") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 815e1c02a..ea6f7c298 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -21,7 +21,7 @@ import k2 import torch from model import Transducer -from icefall.decode import one_best_decoding +from icefall.decode import Nbest, one_best_decoding from icefall.utils import get_texts @@ -33,6 +33,7 @@ def fast_beam_search( beam: float, max_states: int, max_contexts: int, + use_max: bool = False, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. @@ -52,6 +53,9 @@ def fast_beam_search( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return the decoded result. 
""" @@ -98,9 +102,67 @@ def fast_beam_search( decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - best_path = one_best_decoding(lattice) - hyps = get_texts(best_path) - return hyps + if use_max: + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + return hyps + else: + num_paths = 200 + use_double_scores = True + nbest_scale = 0.8 + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. + # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + + tot_scores = path_lattice.get_tot_scores( + use_double_scores=use_double_scores, log_semiring=True + ) + + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + best_hyp_indexes = ragged_tot_scores.argmax() + + best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) + hyps = get_texts(best_path) + return hyps def greedy_search( @@ -272,7 +334,7 @@ class HypothesisList(object): def data(self) -> Dict[str, Hypothesis]: return self._data - def add(self, hyp: Hypothesis) -> None: + def add(self, hyp: Hypothesis, use_max: bool = False) -> None: """Add a Hypothesis to `self`. If `hyp` already exists in `self`, its probability is updated using @@ -281,13 +343,20 @@ class HypothesisList(object): Args: hyp: The hypothesis to be added. + use_max: + True to select the hypothesis with the larger log_prob in case there + already exists a hypothesis whose `ys` equals to `hyp.ys`. + False to use log_add. """ key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp( - old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob - ) + if use_max: + old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) + else: + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -395,6 +464,7 @@ def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + use_max: bool = False, ) -> List[List[int]]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -405,6 +475,9 @@ def modified_beam_search( Output from the encoder. Its shape is (N, T, C). beam: Number of active paths during the beam search. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return a list-of-list of token IDs. ans[i] is the decoding results for the i-th utterance. 
@@ -423,7 +496,8 @@ def modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) + ), + use_max=use_max, ) for t in range(T): @@ -508,6 +582,7 @@ def _deprecated_modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + use_max: bool = False, ) -> List[int]: """It limits the maximum number of symbols per frame to 1. @@ -523,6 +598,9 @@ def _deprecated_modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return the decoded result. """ @@ -543,12 +621,13 @@ def _deprecated_modified_beam_search( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) + ), + use_max=use_max, ) for t in range(T): # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + current_encoder_out = encoder_out[:, t:t + 1, :].unsqueeze(2) # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) # fmt: on A = list(B) @@ -601,7 +680,7 @@ def _deprecated_modified_beam_search( new_ys.append(new_token) new_log_prob = topk_log_probs[i] new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B.add(new_hyp) + B.add(new_hyp, use_max=use_max) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks @@ -613,6 +692,7 @@ def beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + use_max: bool = False, ) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -626,6 +706,9 @@ def beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + use_max: + True to use max operation to select the hypothesis with the largest + log_prob when there are duplicate hypotheses; False to use log-add. Returns: Return the decoded result. 
""" @@ -650,7 +733,9 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0)) + B.add( + Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max + ) max_sym_per_utt = 20000 @@ -709,7 +794,10 @@ def beam_search( new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys - B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) + B.add( + Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob), + use_max=use_max, + ) # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) @@ -718,7 +806,10 @@ def beam_search( continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v - A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob)) + A.add( + Hypothesis(ys=new_ys, log_prob=new_log_prob), + use_max=use_max, + ) # Check whether B contains more than "beam" elements more probable # than the most probable in A diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 8e924bf96..21705848d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -53,6 +53,19 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) fast beam search using LG +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --use-LG True \ + --use-max False \ + --max-duration 1500 \ + --decoding-method fast_beam_search \ + --beam 8 \ + --max-contexts 8 \ + --max-states 64 """ @@ -81,10 +94,12 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -135,6 +150,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -166,6 +188,36 @@ def get_parser(): Used only when --decoding-method is fast_beam_search""", ) + parser.add_argument( + "--use-LG", + type=str2bool, + default=False, + help="""Whether to use an LG graph for FSA-based beam search. + Used only when --decoding_method is fast_beam_search. If setting true, + it assumes there is an LG.pt file in lang_dir.""", + ) + + parser.add_argument( + "--use-max", + type=str2bool, + default=False, + help="""If True, use max-op to select the hypothesis that have the + max log_prob in case of duplicate hypotheses. + If False, use log_add. + Used only for beam_search, modified_beam_search, and fast_beam_search + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search. + It specifies the scale for n-gram LM scores. + """, + ) + parser.add_argument( "--max-contexts", type=int, @@ -205,6 +257,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -228,6 +281,8 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. 
decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search. @@ -259,9 +314,14 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + use_max=params.use_max, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + if params.use_LG: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + else: + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -277,6 +337,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, beam=params.beam_size, + use_max=params.use_max, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -298,6 +359,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + use_max=params.use_max, ) else: raise ValueError( @@ -324,6 +386,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -337,6 +400,8 @@ def decode_dataset( The neural model. sp: The BPE model. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search. @@ -367,8 +432,9 @@ def decode_dataset( params=params, model=model, sp=sp, - decoding_graph=decoding_graph, batch=batch, + word_table=word_table, + decoding_graph=decoding_graph, ) for name, hyps in hyps_dict.items(): @@ -455,11 +521,14 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if "fast_beam_search" in params.decoding_method: + params.suffix += f"-use-LG-{params.use_LG}" params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"-use-max-{params.use_max}" elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" + params.suffix += f"-beam-size-{params.beam_size}" + params.suffix += f"-use-max-{params.use_max}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -507,9 +576,21 @@ def main(): model.device = device if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + if params.use_LG: + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + decoding_graph = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/LG.pt", map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -531,6 +612,7 @@ def main(): params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, )
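
For reference, a minimal standalone sketch (illustrative only, not the patch code) of the two combination rules that the new `use_max` flag selects between when `HypothesisList.add` sees two hypotheses with the same `ys`; the log-probability values below are made up:

    # Sketch of the use_max semantics, with invented example values.
    import torch

    old_log_prob = torch.tensor(-1.2)  # existing hypothesis for some ys
    new_log_prob = torch.tensor(-0.7)  # duplicate hypothesis with the same ys

    # use_max=True: keep only the more probable of the two duplicates.
    combined_max = torch.maximum(old_log_prob, new_log_prob)       # -0.7

    # use_max=False (default): log-add, i.e. sum the probabilities in log space,
    # which is what HypothesisList.add did unconditionally before this patch.
    combined_logadd = torch.logaddexp(old_log_prob, new_log_prob)  # ~ -0.226

When decoding with fast_beam_search and --use-LG True, decode.py expects an existing LG.pt under --lang-dir; presumably it is produced beforehand with the new script, e.g. ./local/compile_lg.py --lang-dir data/lang_bpe_500 (matching the default --lang-dir of decode.py).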