diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py new file mode 100755 index 000000000..d87fdf1b9 --- /dev/null +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input lang_dir and generates LG from + + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_3_gram.fst.txt + +The generated LG is saved in $lang_dir/LG.fst +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_LG(lang_dir: str) -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_500. + + Return: + An FST representing LG. + """ + + tokens = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt") + + assert "#0" in tokens + + first_token_disambig_id = tokens["#0"] + logging.info(f"first token disambig ID: {first_token_disambig_id}") + + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path("data/lm/G_3_gram.pt").is_file(): + logging.info("Loading pre-compiled G_3_gram") + d = torch.load("data/lm/G_3_gram.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info("Loading G_3_gram.fst.txt") + with open("data/lm/G_3_gram.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + del G.aux_labels + torch.save(G.as_dict(), "data/lm/G_3_gram.pt") + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Composing L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}, num_arcs: {LG.num_arcs}") + + del LG.aux_labels + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info( + f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Determinizing LG") + LG = k2.determinize(LG) + logging.info( + f"LG shape after k2.determinize: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + logging.info( + f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Removing disambiguation symbols on LG") + LG.labels[LG.labels >= first_token_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + logging.info("Removing epsilons") + LG = k2.remove_epsilon(LG) + logging.info( + f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Connecting") + LG = k2.connect(LG) + logging.info( + f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info(f"LG properties: {LG.properties_str}") + # Possible properties is: + # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible" + logging.info("Caution: LG is not deterministic!!!") + + return LG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + out_filename = lang_dir / "LG.pt" + + if out_filename.is_file(): + logging.info(f"{out_filename} already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + LG = compile_LG(lang_dir) + logging.info(f"Saving LG to {out_filename}") + torch.save(LG.as_dict(), out_filename) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/local/test_compile_lg.py b/egs/librispeech/ASR/local/test_compile_lg.py new file mode 100755 index 000000000..35eee0586 --- /dev/null +++ b/egs/librispeech/ASR/local/test_compile_lg.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./local/test_compile_lg.py +""" + +from pathlib import Path +from typing import List + +import k2 +import sentencepiece as spm +import torch + +lang_dir = Path("./data/lang_bpe_500") + + +def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]: + """ + Args: + word_table: + Word symbol table. + s: + A string consisting of space(s) separated words. + Returns: + Return a list of word IDs. + """ + ans = [] + for w in s.split(): + ans.append(word_table[w]) + return ans + + +def main(): + assert lang_dir.exists(), f"{lang_dir} does not exist!" + LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu")) + + sp = spm.SentencePieceProcessor() + sp.load(f"{lang_dir}/bpe.model") + + word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt") + s = "HELLO WORLD" + token_ids = sp.encode(s) + + token_fsa = k2.linear_fsa(token_ids) + + fsa = k2.intersect(LG, token_fsa) + fsa = k2.connect(fsa) + print(k2.to_dot(fsa)) + print(fsa.properties_str) + print(LG.properties_str) + # You can use https://dreampuf.github.io/GraphvizOnline/ + # to visualize the output. + # + # You can see that the resulting fsa is not deterministic + # Note: LG is non-deterministic + # + # See https://shorturl.at/uIL69 + # for visualization of the above fsa. + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index c5efb733d..a390c3180 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -17,8 +17,10 @@ from dataclasses import dataclass from typing import Dict, List, Optional +import k2 import torch from model import Transducer +from shallow_fusion import shallow_fusion def greedy_search( @@ -111,6 +113,13 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor + # Used for shallow fusion + # The key of the dict is a state index into LG + # while the corresponding value is the LM score + # reaching this state. + # Note: The value tensor contains only a single entry + ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = (None,) + @property def key(self) -> str: """Return a string representation of self.ys""" @@ -149,6 +158,15 @@ class HypothesisList(object): torch.logaddexp( old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob ) + if hyp.ngram_state_and_scores is not None: + for state, score in hyp.ngram_state_and_scores.items(): + if ( + state in old_hyp.ngram_state_and_scores + and score > old_hyp.ngram_state_and_scores[state] + ): + old_hyp.ngram_state_and_scores[state] = score + else: + old_hyp.ngram_state_and_scores[state] = score else: self._data[key] = hyp @@ -318,6 +336,7 @@ def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + LG: Optional[k2.Fsa] = None, ) -> List[int]: """It limits the maximum number of symbols per frame to 1. @@ -328,9 +347,13 @@ def modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + LG: + Optional. Used for shallow fusion. Returns: Return the decoded result. """ + enable_shallow_fusion = LG is not None + ngram_lm_scale = 0.8 assert encoder_out.ndim == 3 @@ -350,10 +373,19 @@ def modified_beam_search( T = encoder_out.size(1) B = HypothesisList() + + if enable_shallow_fusion: + ngram_state_and_scores = { + 0: torch.zeros(1, dtype=torch.float32, device=device) + } + else: + ngram_state_and_scores = None + B.add( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ngram_state_and_scores=ngram_state_and_scores, ) ) @@ -411,9 +443,33 @@ def modified_beam_search( new_token = topk_token_indexes[i] if new_token != blank_id: new_ys.append(new_token) + else: + ngram_state_and_scores = hyp.ngram_state_and_scores + new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + + if enable_shallow_fusion and new_token != blank_id: + ngram_state_and_scores = shallow_fusion( + LG, new_token, hyp.ngram_state_and_scores + ) + if len(ngram_state_and_scores) == 0: + continue + max_ngram_score = max(ngram_state_and_scores.values()) + new_log_prob += ngram_lm_scale * max_ngram_score + + # TODO: Get the maximum scores in ngram_state_and_scores + # and add it to new_log_prob + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + ngram_state_and_scores=ngram_state_and_scores, + ) + B.add(new_hyp) + if len(B) == 0: + for h in A: + B.add(h) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index c101d9397..d4da693a5 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -40,8 +40,9 @@ import argparse import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple +import k2 import sentencepiece as spm import torch import torch.nn as nn @@ -131,6 +132,13 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--LG", + type=str, + help="""Path to LG.pt for shallow fusion. + Used only when --decoding-method is modified_beam_search.""", + ) + return parser @@ -203,6 +211,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + LG: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -225,6 +234,9 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + LG: + Optional. Used for shallow fusion. Used only when params.decoding_method + is modified_beam_search. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -257,17 +269,24 @@ def decode_one_batch( ) elif params.decoding_method == "beam_search": hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, ) elif params.decoding_method == "modified_beam_search": hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + LG=LG, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) hyps.append(sp.decode(hyp).split()) + for h in hyps: + print(" ".join(h)) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -280,6 +299,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + LG: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -292,6 +312,9 @@ def decode_dataset( The neural model. sp: The BPE model. + LG: + Optional. Used for shallow fusion. Used only when params.decoding_method + is modified_beam_search. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -320,6 +343,7 @@ def decode_dataset( model=model, sp=sp, batch=batch, + LG=LG, ) for name, hyps in hyps_dict.items(): @@ -419,6 +443,21 @@ def main(): logging.info(f"Device: {device}") + if params.LG is not None: + assert ( + params.decoding_method == "modified_beam_search" + ), "--LG is used only when --decoding_method=modified_beam_search" + logging.info(f"Loading LG from {params.LG}") + LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device)) + logging.info(f"LG properties: {LG.properties_str}") + logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}") + # If LG is created by local/compile_lg.py, then it should be epsilon + # free as well as arc sorted + assert "ArcSorted" in LG.properties_str + assert "EpsilonFree" in LG.properties_str + else: + LG = None + sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) @@ -467,6 +506,7 @@ def main(): params=params, model=model, sp=sp, + LG=LG, ) save_results( diff --git a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py new file mode 100644 index 000000000..18ef253dd --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py @@ -0,0 +1,78 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import k2 +import torch + + +def shallow_fusion( + LG: k2.Fsa, + token: int, + state_and_scores: Dict[int, torch.Tensor], +) -> Dict[int, torch.Tensor]: + """ + Args: + LG: + An n-gram. It should be arc sorted and epsilon free. + token: + The input token ID. + state_and_scores: + The keys contain the current state we are in and the + values are the LM log_prob for reaching the corresponding + states from the start state. + Returns: + Return a new state_and_scores. + """ + row_splits = LG.arcs.row_splits(1) + arcs = LG.arcs.values() + + current_states = list(state_and_scores.keys()) + + ans = dict() + for s in current_states: + labels_begin = row_splits[s] + labels_end = row_splits[s + 1] + labels = LG.labels[labels_begin:labels_end].contiguous() + + # As LG is not deterministic, there may be multiple + # out-going arcs that with label equal to "token" + # + # Note: LG is arc sorted! + left = torch.bucketize(token, labels, right=False) + right = torch.bucketize(token, labels, right=True) + + if left >= right: + # There are no out-going arcs from this state + # that have label equal to "token" + continue + + # Now we have + # labels[i] == token + # for + # left <= i < right + + for i in range(left, right): + i += labels_begin + next_state = arcs[i][1].item() + score = LG.scores[i] + if next_state not in ans: + ans[next_state] = score + else: + ans[next_state] = max(score, ans[next_state]) + + return ans