From 954b4efff3ee970d22d861ba6b3930e7cdc763e7 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 8 Feb 2022 20:40:45 +0800
Subject: [PATCH 1/3] WIP: Use shallow fusion in modified beam search.

---
 egs/librispeech/ASR/local/compile_lg.py      | 160 ++++++++++++++++++
 egs/librispeech/ASR/local/test_compile_lg.py |  79 +++++++++
 .../ASR/transducer_stateless/beam_search.py  |  58 ++++++-
 .../ASR/transducer_stateless/decode.py       |  46 ++++-
 .../transducer_stateless/shallow_fusion.py   |  78 +++++++++
 5 files changed, 417 insertions(+), 4 deletions(-)
 create mode 100755 egs/librispeech/ASR/local/compile_lg.py
 create mode 100755 egs/librispeech/ASR/local/test_compile_lg.py
 create mode 100644 egs/librispeech/ASR/transducer_stateless/shallow_fusion.py

diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py
new file mode 100755
index 000000000..d87fdf1b9
--- /dev/null
+++ b/egs/librispeech/ASR/local/compile_lg.py
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates LG from
+
+    - L, the lexicon, built from lang_dir/L_disambig.pt
+
+      Caution: We use a lexicon that contains disambiguation symbols
+
+    - G, the LM, built from data/lm/G_3_gram.fst.txt
+
+The generated LG is saved in $lang_dir/LG.pt
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def compile_LG(lang_dir: str) -> k2.Fsa:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang_phone or data/lang_bpe_500.
+
+    Returns:
+      An FST representing LG.
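+
+    Example (a sketch mirroring what main() below does):
+
+        LG = compile_LG("data/lang_bpe_500")
+        torch.save(LG.as_dict(), "data/lang_bpe_500/LG.pt")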
+ """ + + tokens = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt") + + assert "#0" in tokens + + first_token_disambig_id = tokens["#0"] + logging.info(f"first token disambig ID: {first_token_disambig_id}") + + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path("data/lm/G_3_gram.pt").is_file(): + logging.info("Loading pre-compiled G_3_gram") + d = torch.load("data/lm/G_3_gram.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info("Loading G_3_gram.fst.txt") + with open("data/lm/G_3_gram.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + del G.aux_labels + torch.save(G.as_dict(), "data/lm/G_3_gram.pt") + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Composing L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}, num_arcs: {LG.num_arcs}") + + del LG.aux_labels + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info( + f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Determinizing LG") + LG = k2.determinize(LG) + logging.info( + f"LG shape after k2.determinize: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + logging.info( + f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Removing disambiguation symbols on LG") + LG.labels[LG.labels >= first_token_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + logging.info("Removing epsilons") + LG = k2.remove_epsilon(LG) + logging.info( + f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Connecting") + LG = k2.connect(LG) + logging.info( + f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" + ) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info(f"LG properties: {LG.properties_str}") + # Possible properties is: + # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible" + logging.info("Caution: LG is not deterministic!!!") + + return LG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + out_filename = lang_dir / "LG.pt" + + if out_filename.is_file(): + logging.info(f"{out_filename} already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + LG = compile_LG(lang_dir) + logging.info(f"Saving LG to {out_filename}") + torch.save(LG.as_dict(), out_filename) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/local/test_compile_lg.py b/egs/librispeech/ASR/local/test_compile_lg.py new file mode 100755 index 000000000..35eee0586 --- /dev/null +++ b/egs/librispeech/ASR/local/test_compile_lg.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./local/test_compile_lg.py +""" + +from pathlib import Path +from typing import List + +import k2 +import sentencepiece as spm +import torch + +lang_dir = Path("./data/lang_bpe_500") + + +def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]: + """ + Args: + word_table: + Word symbol table. + s: + A string consisting of space(s) separated words. + Returns: + Return a list of word IDs. + """ + ans = [] + for w in s.split(): + ans.append(word_table[w]) + return ans + + +def main(): + assert lang_dir.exists(), f"{lang_dir} does not exist!" + LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu")) + + sp = spm.SentencePieceProcessor() + sp.load(f"{lang_dir}/bpe.model") + + word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt") + s = "HELLO WORLD" + token_ids = sp.encode(s) + + token_fsa = k2.linear_fsa(token_ids) + + fsa = k2.intersect(LG, token_fsa) + fsa = k2.connect(fsa) + print(k2.to_dot(fsa)) + print(fsa.properties_str) + print(LG.properties_str) + # You can use https://dreampuf.github.io/GraphvizOnline/ + # to visualize the output. + # + # You can see that the resulting fsa is not deterministic + # Note: LG is non-deterministic + # + # See https://shorturl.at/uIL69 + # for visualization of the above fsa. + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index c5efb733d..a390c3180 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -17,8 +17,10 @@ from dataclasses import dataclass from typing import Dict, List, Optional +import k2 import torch from model import Transducer +from shallow_fusion import shallow_fusion def greedy_search( @@ -111,6 +113,13 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor + # Used for shallow fusion + # The key of the dict is a state index into LG + # while the corresponding value is the LM score + # reaching this state. + # Note: The value tensor contains only a single entry + ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = (None,) + @property def key(self) -> str: """Return a string representation of self.ys""" @@ -149,6 +158,15 @@ class HypothesisList(object): torch.logaddexp( old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob ) + if hyp.ngram_state_and_scores is not None: + for state, score in hyp.ngram_state_and_scores.items(): + if ( + state in old_hyp.ngram_state_and_scores + and score > old_hyp.ngram_state_and_scores[state] + ): + old_hyp.ngram_state_and_scores[state] = score + else: + old_hyp.ngram_state_and_scores[state] = score else: self._data[key] = hyp @@ -318,6 +336,7 @@ def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + LG: Optional[k2.Fsa] = None, ) -> List[int]: """It limits the maximum number of symbols per frame to 1. @@ -328,9 +347,13 @@ def modified_beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + LG: + Optional. Used for shallow fusion. Returns: Return the decoded result. 
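+    Note:
+      When LG is given, each non-blank expansion is additionally scored
+      by the n-gram LM: the loop below adds
+      ngram_lm_scale * max_ngram_score to the new hypothesis' log_prob,
+      where max_ngram_score is the best LM score among the LG states
+      returned by shallow_fusion().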
""" + enable_shallow_fusion = LG is not None + ngram_lm_scale = 0.8 assert encoder_out.ndim == 3 @@ -350,10 +373,19 @@ def modified_beam_search( T = encoder_out.size(1) B = HypothesisList() + + if enable_shallow_fusion: + ngram_state_and_scores = { + 0: torch.zeros(1, dtype=torch.float32, device=device) + } + else: + ngram_state_and_scores = None + B.add( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ngram_state_and_scores=ngram_state_and_scores, ) ) @@ -411,9 +443,33 @@ def modified_beam_search( new_token = topk_token_indexes[i] if new_token != blank_id: new_ys.append(new_token) + else: + ngram_state_and_scores = hyp.ngram_state_and_scores + new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + + if enable_shallow_fusion and new_token != blank_id: + ngram_state_and_scores = shallow_fusion( + LG, new_token, hyp.ngram_state_and_scores + ) + if len(ngram_state_and_scores) == 0: + continue + max_ngram_score = max(ngram_state_and_scores.values()) + new_log_prob += ngram_lm_scale * max_ngram_score + + # TODO: Get the maximum scores in ngram_state_and_scores + # and add it to new_log_prob + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + ngram_state_and_scores=ngram_state_and_scores, + ) + B.add(new_hyp) + if len(B) == 0: + for h in A: + B.add(h) best_hyp = B.get_most_probable(length_norm=True) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index c101d9397..d4da693a5 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -40,8 +40,9 @@ import argparse import logging from collections import defaultdict from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple +import k2 import sentencepiece as spm import torch import torch.nn as nn @@ -131,6 +132,13 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--LG", + type=str, + help="""Path to LG.pt for shallow fusion. + Used only when --decoding-method is modified_beam_search.""", + ) + return parser @@ -203,6 +211,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + LG: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -225,6 +234,9 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + LG: + Optional. Used for shallow fusion. Used only when params.decoding_method + is modified_beam_search. Returns: Return the decoding result. See above description for the format of the returned dict. 
@@ -257,17 +269,24 @@ def decode_one_batch( ) elif params.decoding_method == "beam_search": hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, ) elif params.decoding_method == "modified_beam_search": hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + LG=LG, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) hyps.append(sp.decode(hyp).split()) + for h in hyps: + print(" ".join(h)) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -280,6 +299,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + LG: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -292,6 +312,9 @@ def decode_dataset( The neural model. sp: The BPE model. + LG: + Optional. Used for shallow fusion. Used only when params.decoding_method + is modified_beam_search. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -320,6 +343,7 @@ def decode_dataset( model=model, sp=sp, batch=batch, + LG=LG, ) for name, hyps in hyps_dict.items(): @@ -419,6 +443,21 @@ def main(): logging.info(f"Device: {device}") + if params.LG is not None: + assert ( + params.decoding_method == "modified_beam_search" + ), "--LG is used only when --decoding_method=modified_beam_search" + logging.info(f"Loading LG from {params.LG}") + LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device)) + logging.info(f"LG properties: {LG.properties_str}") + logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}") + # If LG is created by local/compile_lg.py, then it should be epsilon + # free as well as arc sorted + assert "ArcSorted" in LG.properties_str + assert "EpsilonFree" in LG.properties_str + else: + LG = None + sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) @@ -467,6 +506,7 @@ def main(): params=params, model=model, sp=sp, + LG=LG, ) save_results( diff --git a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py new file mode 100644 index 000000000..18ef253dd --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py @@ -0,0 +1,78 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import k2 +import torch + + +def shallow_fusion( + LG: k2.Fsa, + token: int, + state_and_scores: Dict[int, torch.Tensor], +) -> Dict[int, torch.Tensor]: + """ + Args: + LG: + An n-gram. It should be arc sorted and epsilon free. + token: + The input token ID. 
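+        It is guaranteed to be a non-blank token; the caller invokes
+        this function only when the predicted token is not blank.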
+ state_and_scores: + The keys contain the current state we are in and the + values are the LM log_prob for reaching the corresponding + states from the start state. + Returns: + Return a new state_and_scores. + """ + row_splits = LG.arcs.row_splits(1) + arcs = LG.arcs.values() + + current_states = list(state_and_scores.keys()) + + ans = dict() + for s in current_states: + labels_begin = row_splits[s] + labels_end = row_splits[s + 1] + labels = LG.labels[labels_begin:labels_end].contiguous() + + # As LG is not deterministic, there may be multiple + # out-going arcs that with label equal to "token" + # + # Note: LG is arc sorted! + left = torch.bucketize(token, labels, right=False) + right = torch.bucketize(token, labels, right=True) + + if left >= right: + # There are no out-going arcs from this state + # that have label equal to "token" + continue + + # Now we have + # labels[i] == token + # for + # left <= i < right + + for i in range(left, right): + i += labels_begin + next_state = arcs[i][1].item() + score = LG.scores[i] + if next_state not in ans: + ans[next_state] = score + else: + ans[next_state] = max(score, ans[next_state]) + + return ans From 5af23efa6966c03e186d3e9cad1854070021e768 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 10 Feb 2022 20:28:59 +0800 Subject: [PATCH 2/3] Keep disambig tokens and backoff arcs in LG. --- egs/librispeech/ASR/local/compile_lg.py | 24 +----- egs/librispeech/ASR/local/test_compile_lg.py | 85 +++++++++++-------- .../ASR/transducer_stateless/beam_search.py | 30 +++++-- .../ASR/transducer_stateless/decode.py | 22 ++++- .../transducer_stateless/shallow_fusion.py | 75 +++++++++++----- 5 files changed, 152 insertions(+), 84 deletions(-) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index d87fdf1b9..97301691e 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -105,31 +105,15 @@ def compile_LG(lang_dir: str) -> k2.Fsa: f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" ) - logging.info("Removing disambiguation symbols on LG") - LG.labels[LG.labels >= first_token_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set LG.properties to None - LG.__dict__["_properties"] = None - - logging.info("Removing epsilons") - LG = k2.remove_epsilon(LG) - logging.info( - f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}" - ) - - logging.info("Connecting") - LG = k2.connect(LG) - logging.info( - f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}" - ) - logging.info("Arc sorting LG") LG = k2.arc_sort(LG) logging.info(f"LG properties: {LG.properties_str}") # Possible properties is: - # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible" - logging.info("Caution: LG is not deterministic!!!") + # "Valid|Nonempty|ArcSorted|ArcSortedAndDeterministic|EpsilonFree|MaybeAccessible|MaybeCoaccessible" # noqa + logging.info( + "Caution: LG is deterministic and contains disambig symbols!!!" 
+ ) return LG diff --git a/egs/librispeech/ASR/local/test_compile_lg.py b/egs/librispeech/ASR/local/test_compile_lg.py index 35eee0586..5131551f9 100755 --- a/egs/librispeech/ASR/local/test_compile_lg.py +++ b/egs/librispeech/ASR/local/test_compile_lg.py @@ -21,58 +21,75 @@ To run this file, do: python ./local/test_compile_lg.py """ +import os + from pathlib import Path -from typing import List import k2 -import sentencepiece as spm import torch lang_dir = Path("./data/lang_bpe_500") +corpus = "test_compile_lg_corpus.txt" +arpa = "test_compile_lg_3_gram.arpa" +G_fst_txt = "test_compile_lg_3_gram.fst.txt" -def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]: +def generate_corpus(): + s = """HELLO WORLD +HELLOA WORLDER +HELLOA WORLDER HELLO +HELLOA WORLDER""" + with open(corpus, "w") as f: + f.write(s) + + +def generate_arpa(): + cmd = f""" + ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text {corpus} \ + -lm {arpa} """ - Args: - word_table: - Word symbol table. - s: - A string consisting of space(s) separated words. - Returns: - Return a list of word IDs. + os.system(cmd) + + +def generate_G(): + cmd = f""" + python3 -m kaldilm \ + --read-symbol-table="{lang_dir}/words.txt" \ + --disambig-symbol='#0' \ + {arpa} > {G_fst_txt} """ - ans = [] - for w in s.split(): - ans.append(word_table[w]) - return ans + os.system(cmd) def main(): - assert lang_dir.exists(), f"{lang_dir} does not exist!" - LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu")) + generate_corpus() + generate_arpa() + generate_G() + with open(G_fst_txt) as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + del G.aux_labels + G.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt") + G.draw("G.pdf", title="G") - sp = spm.SentencePieceProcessor() - sp.load(f"{lang_dir}/bpe.model") + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt") + L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt") - word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt") - s = "HELLO WORLD" - token_ids = sp.encode(s) + L = k2.arc_sort(L) + G = k2.arc_sort(G) - token_fsa = k2.linear_fsa(token_ids) + LG = k2.compose(L, G) + del LG.aux_labels - fsa = k2.intersect(LG, token_fsa) - fsa = k2.connect(fsa) - print(k2.to_dot(fsa)) - print(fsa.properties_str) + LG = k2.determinize(LG) + LG = k2.connect(LG) + LG = k2.arc_sort(LG) print(LG.properties_str) - # You can use https://dreampuf.github.io/GraphvizOnline/ - # to visualize the output. - # - # You can see that the resulting fsa is not deterministic - # Note: LG is non-deterministic - # - # See https://shorturl.at/uIL69 - # for visualization of the above fsa. 
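+    # Draw the rebuilt LG for manual inspection. Note: Fsa.draw()
+    # needs the graphviz Python package installed.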
+    LG.draw("LG.pdf", title="LG")
+    # You can have a look at G.pdf and LG.pdf to get a feel
+    # for what they look like


 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index a390c3180..ecddbf5a9 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -155,9 +155,14 @@ class HypothesisList(object):
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+
+            if True:
+                old_hyp.log_prob = torch.logaddexp(
+                    old_hyp.log_prob, hyp.log_prob
+                )
+            else:
+                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
+
             if hyp.ngram_state_and_scores is not None:
                 for state, score in hyp.ngram_state_and_scores.items():
                     if (
@@ -337,6 +342,7 @@ def modified_beam_search(
     encoder_out: torch.Tensor,
     beam: int = 4,
     LG: Optional[k2.Fsa] = None,
+    ngram_lm_scale: float = 0.1,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.

@@ -349,11 +355,13 @@
         Beam size.
       LG:
         Optional. Used for shallow fusion.
+      ngram_lm_scale:
+        Used only when LG is not None. The total score of a path is
+        am_score + ngram_lm_scale * ngram_lm_score
     Returns:
       Return the decoded result.
     """
     enable_shallow_fusion = LG is not None
-    ngram_lm_scale = 0.8

     assert encoder_out.ndim == 3

@@ -422,6 +430,7 @@ def modified_beam_search(
             encoder_out_len.expand(decoder_out.size(0)),
             decoder_out_len.expand(decoder_out.size(0)),
         )
+        vocab_size = logits.size(-1)

         # logits is of shape (num_hyps, vocab_size)
         log_probs = logits.log_softmax(dim=-1)
@@ -437,6 +446,9 @@ def modified_beam_search(
         topk_hyp_indexes = topk_hyp_indexes.tolist()
         topk_token_indexes = topk_token_indexes.tolist()

+        # import pdb
+        #
+        # pdb.set_trace()
         for i in range(len(topk_hyp_indexes)):
             hyp = A[topk_hyp_indexes[i]]
             new_ys = hyp.ys[:]
@@ -450,12 +462,15 @@ def modified_beam_search(

             if enable_shallow_fusion and new_token != blank_id:
                 ngram_state_and_scores = shallow_fusion(
-                    LG, new_token, hyp.ngram_state_and_scores
+                    LG,
+                    new_token,
+                    hyp.ngram_state_and_scores,
+                    vocab_size,
                 )
                 if len(ngram_state_and_scores) == 0:
                     continue
                 max_ngram_score = max(ngram_state_and_scores.values())
-                new_log_prob += ngram_lm_scale * max_ngram_score
+                new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score

             # TODO: Get the maximum scores in ngram_state_and_scores
             # and add it to new_log_prob
@@ -468,6 +483,9 @@
             B.add(new_hyp)

     if len(B) == 0:
+        import logging
+
+        logging.info("\n*****\nEmpty states!\n***\n")
         for h in A:
             B.add(h)

diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index d4da693a5..b70b97d70 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -139,6 +139,15 @@ def get_parser():
         Used only when --decoding-method is modified_beam_search.""",
     )

+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""Used only when --LG is provided.
+        The total score of a path is am_score + ngram_lm_scale * ngram_lm_score.
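+        A reasonable starting point is the default of 0.1; the best value
+        should be tuned on a dev set.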
+ """, + ) + return parser @@ -279,14 +288,18 @@ def decode_one_batch( encoder_out=encoder_out_i, beam=params.beam_size, LG=LG, + ngram_lm_scale=params.ngram_lm_scale, ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) hyps.append(sp.decode(hyp).split()) + s = "\n" for h in hyps: - print(" ".join(h)) + s += " ".join(h) + s += "\n" + logging.info(s) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -336,6 +349,8 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): + if batch_idx > 10: + break texts = batch["supervisions"]["text"] hyps_dict = decode_one_batch( @@ -452,9 +467,10 @@ def main(): logging.info(f"LG properties: {LG.properties_str}") logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}") # If LG is created by local/compile_lg.py, then it should be epsilon - # free as well as arc sorted + # free, deterministic, and arc sorted assert "ArcSorted" in LG.properties_str assert "EpsilonFree" in LG.properties_str + assert "Deterministic" in LG.properties_str else: LG = None @@ -501,6 +517,8 @@ def main(): test_dl = [test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): + if test_set == "test-other": + break results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py index 18ef253dd..8f1045d45 100644 --- a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py +++ b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py @@ -18,61 +18,92 @@ from typing import Dict import k2 import torch +import copy def shallow_fusion( LG: k2.Fsa, token: int, state_and_scores: Dict[int, torch.Tensor], + vocab_size: int, ) -> Dict[int, torch.Tensor]: """ Args: LG: - An n-gram. It should be arc sorted and epsilon free. + An n-gram. It should be arc sorted, deterministic, and epsilon free. token: The input token ID. state_and_scores: The keys contain the current state we are in and the values are the LM log_prob for reaching the corresponding states from the start state. + vocab_size: + Vocabulary size, including the blank symbol. We assume that + token IDs >= vocab_size are disambig IDs (including the backoff + symbol #0). Returns: Return a new state_and_scores. 
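+      Note: Before matching `token`, the given states are first expanded
+      along out-going arcs whose labels are disambig IDs or #0 (backoff),
+      accumulating the LM scores of those arcs.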
""" row_splits = LG.arcs.row_splits(1) arcs = LG.arcs.values() + state_and_scores = copy.deepcopy(state_and_scores) + current_states = list(state_and_scores.keys()) + # Process out-going arcs with label being disambig tokens and #0 + while len(current_states) > 0: + s = current_states.pop() + labels_begin = row_splits[s] + labels_end = row_splits[s + 1] + labels = LG.labels[labels_begin:labels_end].contiguous() + + for i in reversed(range(labels.numel())): + lab = labels[i] + if lab == -1: + # Note: When sorting arcs, k2 treats arc labels as + # unsigned types + continue + + if lab < vocab_size: + # Since LG is arc sorted, we can exit + # the for loop as soon as we have a label + # with ID less than vocab_size + break + + # This is a diambig token or #0 + idx = labels_begin + i + next_state = arcs[idx][1].item() + score = LG.scores[idx] + state_and_scores[s] + if next_state not in state_and_scores: + state_and_scores[next_state] = score + current_states.append(next_state) + else: + state_and_scores[next_state] = max( + score, state_and_scores[next_state] + ) + + current_states = list(state_and_scores.keys()) ans = dict() for s in current_states: labels_begin = row_splits[s] labels_end = row_splits[s + 1] labels = LG.labels[labels_begin:labels_end].contiguous() - # As LG is not deterministic, there may be multiple - # out-going arcs that with label equal to "token" - # - # Note: LG is arc sorted! - left = torch.bucketize(token, labels, right=False) - right = torch.bucketize(token, labels, right=True) + if labels[-1] == -1: + labels = labels[:-1] - if left >= right: - # There are no out-going arcs from this state - # that have label equal to "token" + pos = torch.searchsorted(labels, token) + if pos >= labels.numel() or labels[pos] != token: continue - # Now we have - # labels[i] == token - # for - # left <= i < right + idx = labels_begin + pos + next_state = arcs[idx][1].item() + score = LG.scores[idx] + state_and_scores[s] - for i in range(left, right): - i += labels_begin - next_state = arcs[i][1].item() - score = LG.scores[i] - if next_state not in ans: - ans[next_state] = score - else: - ans[next_state] = max(score, ans[next_state]) + if next_state not in ans: + ans[next_state] = score + else: + ans[next_state] = max(score, ans[next_state]) return ans From adb54aea91abe211b19ec75eeb422b15a3867405 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 15 Feb 2022 12:33:53 +0800 Subject: [PATCH 3/3] Add backoff arcs to the start state to handle OOV word. --- .../ASR/transducer_stateless/beam_search.py | 343 ++++++++---------- .../ASR/transducer_stateless/decode.py | 40 +- .../transducer_stateless/shallow_fusion.py | 90 +++-- .../ASR/transducer_stateless/utils.py | 219 +++++++++++ 4 files changed, 459 insertions(+), 233 deletions(-) create mode 100644 egs/librispeech/ASR/transducer_stateless/utils.py diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index ecddbf5a9..088885f73 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass from typing import Dict, List, Optional import k2 import torch from model import Transducer from shallow_fusion import shallow_fusion +from utils import Hypothesis, HypothesisList def greedy_search( @@ -103,153 +103,6 @@ def greedy_search( return hyp -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - # Used for shallow fusion - # The key of the dict is a state index into LG - # while the corresponding value is the LM score - # reaching this state. - # Note: The value tensor contains only a single entry - ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = (None,) - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - - if True: - old_hyp.log_prob = torch.logaddexp( - old_hyp.log_prob, hyp.log_prob - ) - else: - old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) - - if hyp.ngram_state_and_scores is not None: - for state, score in hyp.ngram_state_and_scores.items(): - if ( - state in old_hyp.ngram_state_and_scores - and score > old_hyp.ngram_state_and_scores[state] - ): - old_hyp.ngram_state_and_scores[state] = score - else: - old_hyp.ngram_state_and_scores[state] = score - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max( - self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) - ) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. 
- """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - def run_decoder( ys: List[int], model: Transducer, @@ -341,6 +194,113 @@ def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. + """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :] + # current_encoder_out is of shape (1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_hyps, 1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, -1 + ) + + logits = model.joiner( + current_encoder_out, + decoder_out, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) + # logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + +def modified_beam_search_with_shallow_fusion( + 
model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, LG: Optional[k2.Fsa] = None, ngram_lm_scale: float = 0.1, ) -> List[int]: @@ -408,7 +368,14 @@ def modified_beam_search( A = list(B) B = HypothesisList() - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs contains both AM scores and LM scores + ys_log_probs = torch.cat( + [ + hyp.log_prob.reshape(1, 1) + + ngram_lm_scale * max(hyp.ngram_state_and_scores.values()) + for hyp in A + ] + ) # ys_log_probs is of shape (num_hyps, 1) decoder_input = torch.tensor( @@ -434,62 +401,52 @@ def modified_beam_search( # logits is of shape (num_hyps, vocab_size) log_probs = logits.log_softmax(dim=-1) - log_probs.add_(ys_log_probs) + tot_log_probs = log_probs + ys_log_probs - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) + _, topk_indexes = tot_log_probs.reshape(-1).topk(beam) + topk_log_probs = log_probs.reshape(-1)[topk_indexes] # topk_hyp_indexes are indexes into `A` topk_hyp_indexes = topk_indexes // logits.size(-1) topk_token_indexes = topk_indexes % logits.size(-1) - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() + topk_hyp_indexes, indexes = torch.sort(topk_hyp_indexes) + topk_token_indexes = topk_token_indexes[indexes] + topk_log_probs = topk_log_probs[indexes] - # import pdb - # - # pdb.set_trace() - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[i] - if new_token != blank_id: - new_ys.append(new_token) - else: - ngram_state_and_scores = hyp.ngram_state_and_scores + shape = k2.ragged.create_ragged_shape2( + row_ids=topk_hyp_indexes.to(torch.int32), + cached_tot_size=topk_hyp_indexes.numel(), + ) + blank_log_probs = log_probs[topk_hyp_indexes, 0] - new_log_prob = topk_log_probs[i] + row_splits = shape.row_splits(1).tolist() + num_rows = len(row_splits) - 1 + for i in range(num_rows): + start = row_splits[i] + end = row_splits[i + 1] + if start >= end: + # Discard A[i] as other hyps have higher log_probs + continue + tokens = topk_token_indexes[start:end] - if enable_shallow_fusion and new_token != blank_id: - ngram_state_and_scores = shallow_fusion( - LG, - new_token, - hyp.ngram_state_and_scores, - vocab_size, - ) - if len(ngram_state_and_scores) == 0: - continue - max_ngram_score = max(ngram_state_and_scores.values()) - new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score - - # TODO: Get the maximum scores in ngram_state_and_scores - # and add it to new_log_prob - - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - ngram_state_and_scores=ngram_state_and_scores, + hyps = shallow_fusion( + LG, + A[i], + tokens, + topk_log_probs[start:end], + vocab_size, + blank_log_probs[i], ) - - B.add(new_hyp) - if len(B) == 0: - import logging - - logging.info("\n*****\nEmpty states!\n***\n") - for h in A: + for h in hyps: B.add(h) - best_hyp = B.get_most_probable(length_norm=True) + if len(B) > beam: + B = B.topk(beam, ngram_lm_scale=ngram_lm_scale) + + best_hyp = B.get_most_probable( + length_norm=True, ngram_lm_scale=ngram_lm_scale + ) ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks return ys diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index b70b97d70..abd5e7fe9 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -47,7 +47,12 @@ import sentencepiece as spm import 
torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search +from beam_search import ( + beam_search, + greedy_search, + modified_beam_search, + modified_beam_search_with_shallow_fusion, +) from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -283,23 +288,25 @@ def decode_one_batch( beam=params.beam_size, ) elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - LG=LG, - ngram_lm_scale=params.ngram_lm_scale, - ) + if LG is None: + hyp = modified_beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + hyp = modified_beam_search_with_shallow_fusion( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + LG=LG, + ngram_lm_scale=params.ngram_lm_scale, + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) hyps.append(sp.decode(hyp).split()) - s = "\n" - for h in hyps: - s += " ".join(h) - s += "\n" - logging.info(s) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -349,8 +356,6 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): - if batch_idx > 10: - break texts = batch["supervisions"]["text"] hyps_dict = decode_one_batch( @@ -464,6 +469,9 @@ def main(): ), "--LG is used only when --decoding_method=modified_beam_search" logging.info(f"Loading LG from {params.LG}") LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device)) + logging.info( + f"max: {LG.scores.max()}, min: {LG.scores.min()}, mean: {LG.scores.mean()}" + ) logging.info(f"LG properties: {LG.properties_str}") logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}") # If LG is created by local/compile_lg.py, then it should be epsilon @@ -517,8 +525,6 @@ def main(): test_dl = [test_clean_dl, test_other_dl] for test_set, test_dl in zip(test_sets, test_dl): - if test_set == "test-other": - break results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py index 8f1045d45..e74f9e350 100644 --- a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py +++ b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py @@ -19,39 +19,51 @@ from typing import Dict import k2 import torch import copy +from utils import Hypothesis, HypothesisList def shallow_fusion( LG: k2.Fsa, - token: int, - state_and_scores: Dict[int, torch.Tensor], + hyp: Hypothesis, + tokens: torch.Tensor, + log_probs: torch.Tensor, vocab_size: int, -) -> Dict[int, torch.Tensor]: + blank_log_prob: torch.Tensor, +) -> HypothesisList: """ Args: LG: An n-gram. It should be arc sorted, deterministic, and epsilon free. - token: - The input token ID. - state_and_scores: - The keys contain the current state we are in and the - values are the LM log_prob for reaching the corresponding - states from the start state. + It contains disambig IDs and back-off arcs. + hyp: + The current hypothesis. + tokens: + The possible tokens that will be expanded from the given `hyp`. + It is a 1-D tensor of dtype torch.int32. + log_probs: + It contains the acoustic log probabilities of each path that + is extended from `hyp.ys` with `tokens`. + log_probs.shape == tokens.shape. vocab_size: Vocabulary size, including the blank symbol. 
We assume that
+        token IDs >= vocab_size are disambig IDs (including the backoff
+        symbol #0).
+      blank_log_prob:
+        The log_prob for the blank token at this frame. It is from
+        the output of the joiner.
     Returns:
-      Return a new state_and_scores.
+      Return new hypotheses by extending the given `hyp` with tokens in the
+      given `tokens`.
     """
+
     row_splits = LG.arcs.row_splits(1)
     arcs = LG.arcs.values()

-    state_and_scores = copy.deepcopy(state_and_scores)
+    state_and_scores = copy.deepcopy(hyp.ngram_state_and_scores)
+
     current_states = list(state_and_scores.keys())

-    # Process out-going arcs with label being disambig tokens and #0
+    # Process out-going arcs with label equal to disambig tokens or #0
     while len(current_states) > 0:
         s = current_states.pop()
         labels_begin = row_splits[s]
@@ -84,7 +96,9 @@
             )

     current_states = list(state_and_scores.keys())
-    ans = dict()
+    ans = HypothesisList()
+
+    device = log_probs.device
     for s in current_states:
         labels_begin = row_splits[s]
         labels_end = row_splits[s + 1]
@@ -93,17 +107,47 @@
         if labels[-1] == -1:
             labels = labels[:-1]

-        pos = torch.searchsorted(labels, token)
-        if pos >= labels.numel() or labels[pos] != token:
-            continue
+        if s != 0:
+            # We add a backoff arc to the start state. Otherwise,
+            # all active states may die due to an out-of-vocabulary word.
+            new_hyp = Hypothesis(
+                ys=hyp.ys[:],
+                log_prob=hyp.log_prob + blank_log_prob,
+                ngram_state_and_scores={
+                    # -20 is the cost on the backoff arc to the start state.
+                    # As LG.scores.min() is about -16.6, we choose -20 here.
+                    # You may need to tune this value.
+                    0: torch.full((1,), -20, dtype=torch.float32, device=device)
+                },
+            )
+            ans.add(new_hyp)

-        idx = labels_begin + pos
-        next_state = arcs[idx][1].item()
-        score = LG.scores[idx] + state_and_scores[s]
+        pos = torch.searchsorted(labels, tokens)
+        for i in range(pos.numel()):
+            if tokens[i] == 0:
+                # blank ID
+                new_hyp = Hypothesis(
+                    ys=hyp.ys[:],
+                    log_prob=hyp.log_prob + log_probs[i],
+                    ngram_state_and_scores=hyp.ngram_state_and_scores,
+                )
+                ans.add(new_hyp)
+                continue
+            elif pos[i] >= labels.numel() or labels[pos[i]] != tokens[i]:
+                # No out-going arc from this state has a label
+                # equal to tokens[i]
+                continue

-        if next_state not in ans:
-            ans[next_state] = score
-        else:
-            ans[next_state] = max(score, ans[next_state])
+            # Found one arc
+
+            idx = labels_begin + pos[i]
+            next_state = arcs[idx][1].item()
+            score = LG.scores[idx] + state_and_scores[s]
+            new_hyp = Hypothesis(
+                ys=hyp.ys + [tokens[i].item()],
+                log_prob=hyp.log_prob + log_probs[i],
+                ngram_state_and_scores={next_state: score},
+            )
+            ans.add(new_hyp)

     return ans
diff --git a/egs/librispeech/ASR/transducer_stateless/utils.py b/egs/librispeech/ASR/transducer_stateless/utils.py
new file mode 100644
index 000000000..97f6a740d
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_stateless/utils.py
@@ -0,0 +1,219 @@
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import torch
+
+
+@dataclass
+class Hypothesis:
+    # The predicted tokens so far.
+    # Newly predicted tokens are appended to `ys`.
+    ys: List[int]
+
+    # The log prob of ys.
+    # It contains only one entry.
+    # Note: It contains only the acoustic part.
+    log_prob: torch.Tensor
+
+    # Used for shallow fusion.
+    # The key of the dict is a state index into LG
+    # while the corresponding value is the LM score
+    # for reaching this state from the start state.
+    # Note: The value tensor contains only a single entry
+    # and it contains only the LM part.
+    ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None
+
+    @property
+    def key(self) -> str:
+        """Return a string representation of self.ys"""
+        return "_".join(map(str, self.ys))
+
+
+class HypothesisList(object):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
+        """
+        Args:
+          data:
+            A dict of Hypotheses. Its key is its `value.key`.
+        """
+        if data is None:
+            self._data = {}
+        else:
+            self._data = data
+
+    @property
+    def data(self) -> Dict[str, Hypothesis]:
+        return self._data
+
+    def add(self, hyp: Hypothesis) -> None:
+        """Add a Hypothesis to `self`.
+
+        If `hyp` already exists in `self`, the larger `log_prob` of the
+        two is kept; per-state LM scores are merged by taking the max.
+
+        Args:
+          hyp:
+            The hypothesis to be added.
+        """
+        key = hyp.key
+        if key in self:
+            old_hyp = self._data[key]  # shallow copy
+
+            # Note: torch.logaddexp() is a possible alternative here;
+            # taking the max matches how the LM scores are merged below.
+            old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
+
+            if hyp.ngram_state_and_scores is not None:
+                for state, score in hyp.ngram_state_and_scores.items():
+                    if (
+                        state not in old_hyp.ngram_state_and_scores
+                        or score > old_hyp.ngram_state_and_scores[state]
+                    ):
+                        old_hyp.ngram_state_and_scores[state] = score
+        else:
+            self._data[key] = hyp
+
+    def get_most_probable(
+        self, length_norm: bool = False, ngram_lm_scale: Optional[float] = None
+    ) -> Hypothesis:
+        """Get the most probable hypothesis, i.e., the one with
+        the largest `log_prob`.
+
+        Args:
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+          ngram_lm_scale:
+            If not None, it specifies the scale applied to the LM score.
+        Returns:
+          Return the hypothesis that has the largest `log_prob`.
+        """
+        if length_norm:
+            if ngram_lm_scale is None:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: hyp.log_prob / len(hyp.ys),
+                )
+            else:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: (
+                        hyp.log_prob
+                        + ngram_lm_scale
+                        * max(hyp.ngram_state_and_scores.values())
+                    )
+                    / len(hyp.ys),
+                )
+        else:
+            if ngram_lm_scale is None:
+                return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+            else:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: hyp.log_prob
+                    + ngram_lm_scale * max(hyp.ngram_state_and_scores.values()),
+                )
+
+    def remove(self, hyp: Hypothesis) -> None:
+        """Remove a given hypothesis.
+
+        Caution:
+          `self` is modified **in-place**.
+
+        Args:
+          hyp:
+            The hypothesis to be removed from `self`.
+            Note: It must be contained in `self`. Otherwise,
+            an exception is raised.
+ """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter( + self, threshold: torch.Tensor, ngram_lm_scale: Optional[float] = None + ) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Args: + threshold: + Hypotheses with log_prob less than this value are removed. + ngram_lm_scale: + If not None, it specifies the scale applied to the LM score. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + if ngram_lm_scale is None: + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + else: + for _, hyp in self._data.items(): + if ( + hyp.log_prob + + ngram_lm_scale * max(hyp.ngram_state_and_scores.values()) + > threshold + ): + ans.add(hyp) # shallow copy + return ans + + def topk( + self, k: int, ngram_lm_scale: Optional[float] = None + ) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + if ngram_lm_scale is None: + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + else: + hyps = sorted( + hyps, + key=lambda h: h[1].log_prob + + ngram_lm_scale * max(h[1].ngram_state_and_scores.values()), + reverse=True, + )[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s)
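
Example usage (a sketch, assuming the standard icefall recipe layout, a
trained transducer_stateless model, and the usual checkpoint options of
decode.py such as --epoch and --avg):

    # Build LG; the result is saved to ./data/lang_bpe_500/LG.pt
    ./local/compile_lg.py --lang-dir ./data/lang_bpe_500

    # Decode with shallow fusion in modified beam search
    ./transducer_stateless/decode.py \
        --decoding-method modified_beam_search \
        --beam-size 4 \
        --LG ./data/lang_bpe_500/LG.pt \
        --ngram-lm-scale 0.1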