diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py
new file mode 100755
index 000000000..97301691e
--- /dev/null
+++ b/egs/librispeech/ASR/local/compile_lg.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script takes as input lang_dir and generates LG from
+
+    - L, the lexicon, built from lang_dir/L_disambig.pt
+
+      Caution: We use a lexicon that contains disambiguation symbols
+
+    - G, the LM, built from data/lm/G_3_gram.fst.txt
+
+The generated LG is saved in $lang_dir/LG.pt
+"""
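+
+# A hedged usage sketch (the paths follow the recipe layout this script
+# already assumes; nothing here is an extra requirement):
+#
+#   cd egs/librispeech/ASR
+#   ./local/compile_lg.py --lang-dir data/lang_bpe_500
+#
+# It reads lang_dir/tokens.txt and lang_dir/L_disambig.pt as well as
+# data/lm/G_3_gram.fst.txt and writes the result to lang_dir/LG.pt.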
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
+def compile_LG(lang_dir: str) -> k2.Fsa:
+    """
+    Args:
+      lang_dir:
+        The language directory, e.g., data/lang_phone or data/lang_bpe_500.
+
+    Returns:
+      An FST representing LG.
+    """
+    tokens = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
+
+    assert "#0" in tokens
+
+    first_token_disambig_id = tokens["#0"]
+    logging.info(f"first token disambig ID: {first_token_disambig_id}")
+
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+
+    if Path("data/lm/G_3_gram.pt").is_file():
+        logging.info("Loading pre-compiled G_3_gram")
+        d = torch.load("data/lm/G_3_gram.pt")
+        G = k2.Fsa.from_dict(d)
+    else:
+        logging.info("Loading G_3_gram.fst.txt")
+        with open("data/lm/G_3_gram.fst.txt") as f:
+            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+            del G.aux_labels
+            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
+
+    L = k2.arc_sort(L)
+    G = k2.arc_sort(G)
+
+    logging.info("Composing L and G")
+    LG = k2.compose(L, G)
+    logging.info(f"LG shape: {LG.shape}, num_arcs: {LG.num_arcs}")
+
+    del LG.aux_labels
+
+    logging.info("Connecting LG")
+    LG = k2.connect(LG)
+    logging.info(
+        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
+    )
+
+    logging.info("Determinizing LG")
+    LG = k2.determinize(LG)
+    logging.info(
+        f"LG shape after k2.determinize: {LG.shape}, num_arcs: {LG.num_arcs}"
+    )
+
+    logging.info("Connecting LG after k2.determinize")
+    LG = k2.connect(LG)
+    logging.info(
+        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
+    )
+
+    logging.info("Arc sorting LG")
+    LG = k2.arc_sort(LG)
+
+    logging.info(f"LG properties: {LG.properties_str}")
+    # Possible properties are:
+    # "Valid|Nonempty|ArcSorted|ArcSortedAndDeterministic|EpsilonFree|MaybeAccessible|MaybeCoaccessible"  # noqa
+    logging.info(
+        "Caution: LG is deterministic and contains disambig symbols!!!"
+    )
+
+    return LG
+
+
+def main():
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+    out_filename = lang_dir / "LG.pt"
+
+    if out_filename.is_file():
+        logging.info(f"{out_filename} already exists - skipping")
+        return
+
+    logging.info(f"Processing {lang_dir}")
+
+    LG = compile_LG(lang_dir)
+    logging.info(f"Saving LG to {out_filename}")
+    torch.save(LG.as_dict(), out_filename)
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/librispeech/ASR/local/test_compile_lg.py b/egs/librispeech/ASR/local/test_compile_lg.py
new file mode 100755
index 000000000..5131551f9
--- /dev/null
+++ b/egs/librispeech/ASR/local/test_compile_lg.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./local/test_compile_lg.py
+"""
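+
+# Note: this test assumes that ./data/lang_bpe_500 already exists (with
+# words.txt, tokens.txt, and L_disambig.pt in it), that
+# ./shared/make_kn_lm.py is available, and that kaldilm is installed.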
+
+import os
+
+from pathlib import Path
+
+import k2
+import torch
+
+lang_dir = Path("./data/lang_bpe_500")
+corpus = "test_compile_lg_corpus.txt"
+arpa = "test_compile_lg_3_gram.arpa"
+G_fst_txt = "test_compile_lg_3_gram.fst.txt"
+
+
+def generate_corpus():
+    s = """HELLO WORLD
+HELLOA WORLDER
+HELLOA WORLDER HELLO
+HELLOA WORLDER"""
+    with open(corpus, "w") as f:
+        f.write(s)
+
+
+def generate_arpa():
+    cmd = f"""
+      ./shared/make_kn_lm.py \
+        -ngram-order 3 \
+        -text {corpus} \
+        -lm {arpa}
+    """
+    os.system(cmd)
+
+
+def generate_G():
+    cmd = f"""
+      python3 -m kaldilm \
+        --read-symbol-table="{lang_dir}/words.txt" \
+        --disambig-symbol='#0' \
+        {arpa} > {G_fst_txt}
+    """
+    os.system(cmd)
+
+
+def main():
+    generate_corpus()
+    generate_arpa()
+    generate_G()
+    with open(G_fst_txt) as f:
+        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+        del G.aux_labels
+    G.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
+    G.draw("G.pdf", title="G")
+
+    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+    L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
+    L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
+
+    L = k2.arc_sort(L)
+    G = k2.arc_sort(G)
+
+    LG = k2.compose(L, G)
+    del LG.aux_labels
+
+    LG = k2.determinize(LG)
+    LG = k2.connect(LG)
+    LG = k2.arc_sort(LG)
+    print(LG.properties_str)
+    LG.draw("LG.pdf", title="LG")
+    # You can have a look at G.pdf and LG.pdf to get a feel
+    # for what they look like
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py
index c5efb733d..088885f73 100644
--- a/egs/librispeech/ASR/transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
 from typing import Dict, List, Optional
 
+import k2
 import torch
 from model import Transducer
+from shallow_fusion import shallow_fusion
+from utils import Hypothesis, HypothesisList
 
 
 def greedy_search(
@@ -101,132 +103,6 @@ def greedy_search(
     return hyp
 
 
-@dataclass
-class Hypothesis:
-    # The predicted tokens so far.
-    # Newly predicted tokens are appended to `ys`.
-    ys: List[int]
-
-    # The log prob of ys.
-    # It contains only one entry.
-    log_prob: torch.Tensor
-
-    @property
-    def key(self) -> str:
-        """Return a string representation of self.ys"""
-        return "_".join(map(str, self.ys))
-
-
-class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
-        """
-        Args:
-          data:
-            A dict of Hypotheses. Its key is its `value.key`.
-        """
-        if data is None:
-            self._data = {}
-        else:
-            self._data = data
-
-    @property
-    def data(self) -> Dict[str, Hypothesis]:
-        return self._data
-
-    def add(self, hyp: Hypothesis) -> None:
-        """Add a Hypothesis to `self`.
-
-        If `hyp` already exists in `self`, its probability is updated using
-        `log-sum-exp` with the existed one.
-
-        Args:
-          hyp:
-            The hypothesis to be added.
-        """
-        key = hyp.key
-        if key in self:
-            old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
-        else:
-            self._data[key] = hyp
-
-    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
-        """Get the most probable hypothesis, i.e., the one with
-        the largest `log_prob`.
-
-        Args:
-          length_norm:
-            If True, the `log_prob` of a hypothesis is normalized by the
-            number of tokens in it.
-        Returns:
-          Return the hypothesis that has the largest `log_prob`.
-        """
-        if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
-        else:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
-
-    def remove(self, hyp: Hypothesis) -> None:
-        """Remove a given hypothesis.
-
-        Caution:
-          `self` is modified **in-place**.
-
-        Args:
-          hyp:
-            The hypothesis to be removed from `self`.
-            Note: It must be contained in `self`. Otherwise,
-            an exception is raised.
-        """
-        key = hyp.key
-        assert key in self, f"{key} does not exist"
-        del self._data[key]
-
-    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
-        """Remove all Hypotheses whose log_prob is less than threshold.
-
-        Caution:
-          `self` is not modified. Instead, a new HypothesisList is returned.
-
-        Returns:
-          Return a new HypothesisList containing all hypotheses from `self`
-          with `log_prob` being greater than the given `threshold`.
-        """
-        ans = HypothesisList()
-        for _, hyp in self._data.items():
-            if hyp.log_prob > threshold:
-                ans.add(hyp)  # shallow copy
-        return ans
-
-    def topk(self, k: int) -> "HypothesisList":
-        """Return the top-k hypothesis."""
-        hyps = list(self._data.items())
-
-        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
-
-        ans = HypothesisList(dict(hyps))
-        return ans
-
-    def __contains__(self, key: str):
-        return key in self._data
-
-    def __iter__(self):
-        return iter(self._data.values())
-
-    def __len__(self) -> int:
-        return len(self._data)
-
-    def __str__(self) -> str:
-        s = []
-        for key in self:
-            s.append(key)
-        return ", ".join(s)
-
-
 def run_decoder(
     ys: List[int],
     model: Transducer,
@@ -421,6 +297,161 @@ def modified_beam_search(
     return ys
 
 
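+# A minimal usage sketch for the shallow-fusion search below (illustrative
+# names; `model` is a trained Transducer and `encoder_out` has shape
+# (1, T, C); LG.pt is built by ./local/compile_lg.py):
+#
+#     LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt"))
+#     hyp_tokens = modified_beam_search_with_shallow_fusion(
+#         model=model,
+#         encoder_out=encoder_out,
+#         beam=4,
+#         LG=LG,
+#         ngram_lm_scale=0.1,
+#     )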
+def modified_beam_search_with_shallow_fusion(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
+    LG: Optional[k2.Fsa] = None,
+    ngram_lm_scale: float = 0.1,
+) -> List[int]:
+    """It limits the maximum number of symbols per frame to 1.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder. Only N==1 is
+        supported for now.
+      beam:
+        Beam size.
+      LG:
+        Optional. Used for shallow fusion.
+      ngram_lm_scale:
+        Used only when LG is not None. The total score of a path is
+        am_score + ngram_lm_scale * ngram_lm_score
+    Returns:
+      Return the decoded result.
+    """
+    enable_shallow_fusion = LG is not None
+
+    assert encoder_out.ndim == 3
+
+    # support only batch_size == 1 for now
+    assert encoder_out.size(0) == 1, encoder_out.size(0)
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    device = model.device
+
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+
+    T = encoder_out.size(1)
+
+    B = HypothesisList()
+
+    if enable_shallow_fusion:
+        ngram_state_and_scores = {
+            0: torch.zeros(1, dtype=torch.float32, device=device)
+        }
+    else:
+        ngram_state_and_scores = None
+
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            ngram_state_and_scores=ngram_state_and_scores,
+        )
+    )
+
+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
+    for t in range(T):
+        # fmt: off
+        current_encoder_out = encoder_out[:, t:t+1, :]
+        # current_encoder_out is of shape (1, 1, encoder_out_dim)
+        # fmt: on
+        A = list(B)
+        B = HypothesisList()
+
+        # ys_log_probs contains both AM scores and LM scores
+        ys_log_probs = torch.cat(
+            [
+                hyp.log_prob.reshape(1, 1)
+                + ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
+                for hyp in A
+            ]
+        )
+        # ys_log_probs is of shape (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyp in A],
+            device=device,
+        )
+        # decoder_input is of shape (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_out is of shape (num_hyps, 1, decoder_output_dim)
+
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        vocab_size = logits.size(-1)
+        # logits is of shape (num_hyps, vocab_size)
+        log_probs = logits.log_softmax(dim=-1)
+
+        tot_log_probs = log_probs + ys_log_probs
+
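+        # Note: candidates are ranked by the combined AM+LM score in
+        # `tot_log_probs`, but the values carried forward in
+        # `topk_log_probs` come from `log_probs`, i.e., the acoustic part
+        # only. The LM part is tracked separately in each hypothesis's
+        # `ngram_state_and_scores` (see utils.py).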
+        _, topk_indexes = tot_log_probs.reshape(-1).topk(beam)
+        topk_log_probs = log_probs.reshape(-1)[topk_indexes]
+
+        # topk_hyp_indexes are indexes into `A`
+        topk_hyp_indexes = topk_indexes // logits.size(-1)
+        topk_token_indexes = topk_indexes % logits.size(-1)
+
+        topk_hyp_indexes, indexes = torch.sort(topk_hyp_indexes)
+        topk_token_indexes = topk_token_indexes[indexes]
+        topk_log_probs = topk_log_probs[indexes]
+
+        shape = k2.ragged.create_ragged_shape2(
+            row_ids=topk_hyp_indexes.to(torch.int32),
+            cached_tot_size=topk_hyp_indexes.numel(),
+        )
+        blank_log_probs = log_probs[topk_hyp_indexes, 0]
+
+        row_splits = shape.row_splits(1).tolist()
+        num_rows = len(row_splits) - 1
+        for i in range(num_rows):
+            start = row_splits[i]
+            end = row_splits[i + 1]
+            if start >= end:
+                # Discard A[i] as other hyps have higher log_probs
+                continue
+            tokens = topk_token_indexes[start:end]
+
+            hyps = shallow_fusion(
+                LG,
+                A[i],
+                tokens,
+                topk_log_probs[start:end],
+                vocab_size,
+                blank_log_probs[i],
+            )
+            for h in hyps:
+                B.add(h)
+
+        if len(B) > beam:
+            B = B.topk(beam, ngram_lm_scale=ngram_lm_scale)
+
+    best_hyp = B.get_most_probable(
+        length_norm=True, ngram_lm_scale=ngram_lm_scale
+    )
+    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
+
+    return ys
+
+
 def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index f23a3a300..61d990b69 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -49,13 +49,19 @@ import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
+import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    greedy_search,
+    modified_beam_search,
+    modified_beam_search_with_shallow_fusion,
+)
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -140,6 +146,22 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--LG",
+        type=str,
+        help="""Path to LG.pt for shallow fusion.
+        Used only when --decoding-method is modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""Used only when --LG is provided.
+        The total score of a path is am_score + ngram_lm_scale * ngram_lm_score.
+        """,
+    )
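+
+    # A hedged example invocation (paths are illustrative; LG.pt is the
+    # output of ./local/compile_lg.py):
+    #
+    #   ./transducer_stateless/decode.py \
+    #     --decoding-method modified_beam_search \
+    #     --LG ./data/lang_bpe_500/LG.pt \
+    #     --ngram-lm-scale 0.1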
+
     return parser
 
 
@@ -212,6 +234,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    LG: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -234,6 +257,9 @@
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      LG:
+        Optional. Used for shallow fusion. Used only when
+        params.decoding_method is modified_beam_search.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -266,12 +292,25 @@
         )
     elif params.decoding_method == "beam_search":
         hyp = beam_search(
-            model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            model=model,
+            encoder_out=encoder_out_i,
+            beam=params.beam_size,
         )
     elif params.decoding_method == "modified_beam_search":
-        hyp = modified_beam_search(
-            model=model, encoder_out=encoder_out_i, beam=params.beam_size
-        )
+        if LG is None:
+            hyp = modified_beam_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                beam=params.beam_size,
+            )
+        else:
+            hyp = modified_beam_search_with_shallow_fusion(
+                model=model,
+                encoder_out=encoder_out_i,
+                beam=params.beam_size,
+                LG=LG,
+                ngram_lm_scale=params.ngram_lm_scale,
+            )
     else:
         raise ValueError(
             f"Unsupported decoding method: {params.decoding_method}"
@@ -289,6 +328,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    LG: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -301,6 +341,9 @@
         The neural model.
       sp:
         The BPE model.
+      LG:
+        Optional. Used for shallow fusion. Used only when
+        params.decoding_method is modified_beam_search.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -329,6 +372,7 @@
                 model=model,
                 sp=sp,
                 batch=batch,
+                LG=LG,
             )
 
         for name, hyps in hyps_dict.items():
@@ -428,6 +472,25 @@ def main():
 
     logging.info(f"Device: {device}")
 
+    if params.LG is not None:
+        assert (
+            params.decoding_method == "modified_beam_search"
+        ), "--LG is used only when --decoding-method=modified_beam_search"
+        logging.info(f"Loading LG from {params.LG}")
+        LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device))
+        logging.info(
+            f"max: {LG.scores.max()}, min: {LG.scores.min()}, mean: {LG.scores.mean()}"
+        )
+        logging.info(f"LG properties: {LG.properties_str}")
+        logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
+        # If LG is created by local/compile_lg.py, then it should be epsilon
+        # free, deterministic, and arc sorted
+        assert "ArcSorted" in LG.properties_str
+        assert "EpsilonFree" in LG.properties_str
+        assert "Deterministic" in LG.properties_str
+    else:
+        LG = None
+
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
@@ -476,6 +539,7 @@
             params=params,
             model=model,
             sp=sp,
+            LG=LG,
         )
 
         save_results(
diff --git a/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py
new file mode 100644
index 000000000..e74f9e350
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_stateless/shallow_fusion.py
@@ -0,0 +1,153 @@
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
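+
+# This file implements shallow fusion of an n-gram LM, represented as an
+# FST (e.g., LG.pt built by ./local/compile_lg.py), into the beam search
+# of a transducer model. See modified_beam_search_with_shallow_fusion in
+# ./beam_search.py for how it is invoked.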
+
+import copy
+
+import k2
+import torch
+from utils import Hypothesis, HypothesisList
+
+
+def shallow_fusion(
+    LG: k2.Fsa,
+    hyp: Hypothesis,
+    tokens: torch.Tensor,
+    log_probs: torch.Tensor,
+    vocab_size: int,
+    blank_log_prob: torch.Tensor,
+) -> HypothesisList:
+    """
+    Args:
+      LG:
+        An n-gram LM represented as an FST. It should be arc sorted,
+        deterministic, and epsilon free. It contains disambig IDs and
+        back-off arcs.
+      hyp:
+        The current hypothesis.
+      tokens:
+        The possible tokens that will be expanded from the given `hyp`.
+        It is a 1-D tensor of dtype torch.int32.
+      log_probs:
+        It contains the acoustic log probabilities of each path that
+        is extended from `hyp.ys` with `tokens`.
+        log_probs.shape == tokens.shape.
+      vocab_size:
+        Vocabulary size, including the blank symbol. We assume that
+        token IDs >= vocab_size are disambig IDs (including the backoff
+        symbol #0).
+      blank_log_prob:
+        The log_prob for the blank token at this frame. It is from
+        the output of the joiner.
+    Returns:
+      Return new hypotheses by extending the given `hyp` with tokens in the
+      given `tokens`.
+    """
+    row_splits = LG.arcs.row_splits(1)
+    arcs = LG.arcs.values()
+
+    state_and_scores = copy.deepcopy(hyp.ngram_state_and_scores)
+
+    current_states = list(state_and_scores.keys())
+
+    # Process outgoing arcs with label equal to disambig tokens or #0
+    while len(current_states) > 0:
+        s = current_states.pop()
+        labels_begin = row_splits[s]
+        labels_end = row_splits[s + 1]
+        labels = LG.labels[labels_begin:labels_end].contiguous()
+
+        for i in reversed(range(labels.numel())):
+            lab = labels[i]
+            if lab == -1:
+                # Note: When sorting arcs, k2 treats arc labels as
+                # unsigned types
+                continue
+
+            if lab < vocab_size:
+                # Since LG is arc sorted, we can exit
+                # the for loop as soon as we have a label
+                # with ID less than vocab_size
+                break
+
+            # This is a disambig token or #0
+            idx = labels_begin + i
+            next_state = arcs[idx][1].item()
+            score = LG.scores[idx] + state_and_scores[s]
+            if next_state not in state_and_scores:
+                state_and_scores[next_state] = score
+                current_states.append(next_state)
+            else:
+                state_and_scores[next_state] = max(
+                    score, state_and_scores[next_state]
+                )
+
+    current_states = list(state_and_scores.keys())
+    ans = HypothesisList()
+
+    device = log_probs.device
+    for s in current_states:
+        labels_begin = row_splits[s]
+        labels_end = row_splits[s + 1]
+        labels = LG.labels[labels_begin:labels_end].contiguous()
+
+        if labels[-1] == -1:
+            labels = labels[:-1]
+
+        if s != 0:
+            # We add a backoff arc to the start state. Otherwise,
+            # all active states may die due to an out-of-vocabulary word.
+            new_hyp = Hypothesis(
+                ys=hyp.ys[:],
+                log_prob=hyp.log_prob + blank_log_prob,
+                ngram_state_and_scores={
+                    # -20 is the cost on the backoff arc to the start state.
+                    # As LG.scores.min() is about -16.6, we choose -20 here.
+                    # You may need to tune this value.
+                    0: torch.full((1,), -20, dtype=torch.float32, device=device)
+                },
+            )
+            ans.add(new_hyp)
+
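+        # Since LG is arc sorted, the outgoing labels of state `s` are
+        # sorted as well, so every candidate token can be matched against
+        # them with a single binary search.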
+        pos = torch.searchsorted(labels, tokens)
+        for i in range(pos.numel()):
+            if tokens[i] == 0:
+                # blank ID
+                new_hyp = Hypothesis(
+                    ys=hyp.ys[:],
+                    log_prob=hyp.log_prob + log_probs[i],
+                    ngram_state_and_scores=hyp.ngram_state_and_scores,
+                )
+                ans.add(new_hyp)
+                continue
+            elif pos[i] >= labels.numel() or labels[pos[i]] != tokens[i]:
+                # No outgoing arc from this state has a label
+                # equal to tokens[i]
+                continue
+
+            # Found one arc
+
+            idx = labels_begin + pos[i]
+            next_state = arcs[idx][1].item()
+            score = LG.scores[idx] + state_and_scores[s]
+            new_hyp = Hypothesis(
+                ys=hyp.ys + [tokens[i].item()],
+                log_prob=hyp.log_prob + log_probs[i],
+                ngram_state_and_scores={next_state: score},
+            )
+            ans.add(new_hyp)
+
+    return ans
diff --git a/egs/librispeech/ASR/transducer_stateless/utils.py b/egs/librispeech/ASR/transducer_stateless/utils.py
new file mode 100644
index 000000000..97f6a740d
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_stateless/utils.py
@@ -0,0 +1,219 @@
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import torch
+
+
+@dataclass
+class Hypothesis:
+    # The predicted tokens so far.
+    # Newly predicted tokens are appended to `ys`.
+    ys: List[int]
+
+    # The log prob of ys.
+    # It contains only one entry.
+    # Note: It contains only the acoustic part.
+    log_prob: torch.Tensor
+
+    # Used for shallow fusion
+    # The key of the dict is a state index into LG
+    # while the corresponding value is the LM score
+    # reaching this state from the start state.
+    # Note: The value tensor contains only a single entry
+    # and it contains only the LM part.
+    ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None
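+    # For example, ngram_state_and_scores = {0: torch.zeros(1)} means the
+    # hypothesis is in the start state of LG with an LM score of 0; this is
+    # how modified_beam_search_with_shallow_fusion initializes it.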
+
+    @property
+    def key(self) -> str:
+        """Return a string representation of self.ys"""
+        return "_".join(map(str, self.ys))
+
+
+class HypothesisList(object):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
+        """
+        Args:
+          data:
+            A dict of Hypotheses. Its key is its `value.key`.
+        """
+        if data is None:
+            self._data = {}
+        else:
+            self._data = data
+
+    @property
+    def data(self) -> Dict[str, Hypothesis]:
+        return self._data
+
+    def add(self, hyp: Hypothesis) -> None:
+        """Add a Hypothesis to `self`.
+
+        If `hyp` already exists in `self`, its probability is updated
+        using `max` with the existing one.
+
+        Args:
+          hyp:
+            The hypothesis to be added.
+        """
+        key = hyp.key
+        if key in self:
+            old_hyp = self._data[key]  # shallow copy
+
+            old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
+
+            if hyp.ngram_state_and_scores is not None:
+                for state, score in hyp.ngram_state_and_scores.items():
+                    # Keep the better LM score for each LG state; add the
+                    # state if it is not present yet.
+                    if (
+                        state not in old_hyp.ngram_state_and_scores
+                        or score > old_hyp.ngram_state_and_scores[state]
+                    ):
+                        old_hyp.ngram_state_and_scores[state] = score
+        else:
+            self._data[key] = hyp
+
+    def get_most_probable(
+        self, length_norm: bool = False, ngram_lm_scale: Optional[float] = None
+    ) -> Hypothesis:
+        """Get the most probable hypothesis, i.e., the one with
+        the largest `log_prob`.
+
+        Args:
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+          ngram_lm_scale:
+            If not None, it specifies the scale applied to the LM score.
+        Returns:
+          Return the hypothesis that has the largest `log_prob`.
+        """
+        if length_norm:
+            if ngram_lm_scale is None:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: hyp.log_prob / len(hyp.ys),
+                )
+            else:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: (
+                        hyp.log_prob
+                        + ngram_lm_scale
+                        * max(hyp.ngram_state_and_scores.values())
+                    )
+                    / len(hyp.ys),
+                )
+        else:
+            if ngram_lm_scale is None:
+                return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+            else:
+                return max(
+                    self._data.values(),
+                    key=lambda hyp: hyp.log_prob
+                    + ngram_lm_scale
+                    * max(hyp.ngram_state_and_scores.values()),
+                )
+
+    def remove(self, hyp: Hypothesis) -> None:
+        """Remove a given hypothesis.
+
+        Caution:
+          `self` is modified **in-place**.
+
+        Args:
+          hyp:
+            The hypothesis to be removed from `self`.
+            Note: It must be contained in `self`. Otherwise,
+            an exception is raised.
+        """
+        key = hyp.key
+        assert key in self, f"{key} does not exist"
+        del self._data[key]
+
+    def filter(
+        self, threshold: torch.Tensor, ngram_lm_scale: Optional[float] = None
+    ) -> "HypothesisList":
+        """Remove all Hypotheses whose log_prob is less than threshold.
+
+        Caution:
+          `self` is not modified. Instead, a new HypothesisList is returned.
+
+        Args:
+          threshold:
+            Hypotheses with log_prob less than this value are removed.
+          ngram_lm_scale:
+            If not None, it specifies the scale applied to the LM score.
+
+        Returns:
+          Return a new HypothesisList containing all hypotheses from `self`
+          with `log_prob` being greater than the given `threshold`.
+        """
+        ans = HypothesisList()
+        if ngram_lm_scale is None:
+            for _, hyp in self._data.items():
+                if hyp.log_prob > threshold:
+                    ans.add(hyp)  # shallow copy
+        else:
+            for _, hyp in self._data.items():
+                if (
+                    hyp.log_prob
+                    + ngram_lm_scale
+                    * max(hyp.ngram_state_and_scores.values())
+                    > threshold
+                ):
+                    ans.add(hyp)  # shallow copy
+        return ans
+
+    def topk(
+        self, k: int, ngram_lm_scale: Optional[float] = None
+    ) -> "HypothesisList":
+        """Return the top-k hypotheses."""
+        hyps = list(self._data.items())
+
+        if ngram_lm_scale is None:
+            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+        else:
+            hyps = sorted(
+                hyps,
+                key=lambda h: h[1].log_prob
+                + ngram_lm_scale * max(h[1].ngram_state_and_scores.values()),
+                reverse=True,
+            )[:k]
+
+        ans = HypothesisList(dict(hyps))
+        return ans
+
+    def __contains__(self, key: str):
+        return key in self._data
+
+    def __iter__(self):
+        return iter(self._data.values())
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __str__(self) -> str:
+        s = []
+        for key in self:
+            s.append(key)
+        return ", ".join(s)
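+
+
+# A minimal sketch of how Hypothesis and HypothesisList fit together
+# (illustrative only; ys=[0, 0] assumes blank_id == 0 and
+# context_size == 2, as in the stateless transducer recipes):
+#
+#   hyps = HypothesisList()
+#   hyps.add(Hypothesis(ys=[0, 0], log_prob=torch.zeros(1)))
+#   hyps.add(Hypothesis(ys=[0, 0, 5], log_prob=torch.full((1,), -0.5)))
+#   best = hyps.get_most_probable(length_norm=True)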