diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index bd2d6e258..51a321572 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -24,7 +24,7 @@ import sentencepiece as spm
 import torch
 from model import Transducer
 
-from icefall import NgramLm, NgramLmStateCost
+from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
 from icefall.lm_wrapper import LmScorer
 from icefall.rnn_lm.model import RnnLmModel
@@ -742,6 +742,9 @@ class Hypothesis:
     # N-gram LM state
     state_cost: Optional[NgramLmStateCost] = None
 
+    # Context graph state
+    context_state: Optional[ContextState] = None
+
     @property
     def key(self) -> str:
         """Return a string representation of self.ys"""
@@ -883,6 +886,7 @@ def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
+    context_graph: Optional[ContextGraph] = None,
     beam: int = 4,
     temperature: float = 1.0,
     return_timestamps: bool = False,
@@ -934,6 +938,7 @@ def modified_beam_search(
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                context_state=ContextState(state_id=0),
                 timestamp=[],
             )
         )
@@ -1017,17 +1022,53 @@ def modified_beam_search(
                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
                 new_timestamp = hyp.timestamp[:]
+                new_context_state = None
                 if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
                     new_timestamp.append(t)
-
-                new_log_prob = topk_log_probs[k]
+                    if context_graph is not None:
+                        new_context_state = context_graph.get_next_state(
+                            hyp.context_state.state_id, new_token
+                        )
+
+                new_log_prob = topk_log_probs[k] + (
+                    0
+                    if new_context_state is None
+                    else new_context_state.score
+                )
+
                 new_hyp = Hypothesis(
-                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    timestamp=new_timestamp,
+                    context_state=hyp.context_state
+                    if new_context_state is None
+                    else new_context_state,
                 )
                 B[i].add(new_hyp)
 
     B = B + finalized_B
+
+    # Finalize context_state: if a matched context does not reach the final
+    # state, we need to add the score on the corresponding backoff arc.
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                if hyp.context_state.state_id != 0:
+                    new_context_state = context_graph.get_next_state(
+                        hyp.context_state.state_id, 0
+                    )
+                    finalized_B[i].add(
+                        Hypothesis(
+                            ys=hyp.ys,
+                            log_prob=hyp.log_prob + new_context_state.score,
+                            timestamp=hyp.timestamp,
+                            context_state=new_context_state,
+                        )
+                    )
+                else:
+                    finalized_B[i].add(hyp)
+        B = finalized_B
+
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
 
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
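Note on the beam_search.py hunks above: every Hypothesis now carries a context_state, each emitted non-blank, non-unk token advances that state via ContextGraph.get_next_state, and the returned arc score is added to the hypothesis log-probability. Hypotheses that end the utterance in the middle of a phrase take the backoff arc (label 0), whose negative weight cancels the partially accumulated bonus. Below is a minimal sketch of that bookkeeping on a flat token sequence, outside the beam search; score_tokens and the token list are illustrative names, and only the ContextGraph API (added in icefall/context_graph.py further down) comes from this patch.

# A sketch, not shipped code: replay the per-token context update and the
# final backoff correction performed in modified_beam_search above.
def score_tokens(context_graph, tokens):
    state_id = 0  # every hypothesis starts at the graph's start state
    total = 0.0
    for token in tokens:
        state = context_graph.get_next_state(state_id, token)
        total += state.score  # per-token bonus, or a backoff correction
        state_id = state.state_id
    if state_id != 0:
        # stranded inside a phrase: the backoff arc (label 0) carries the
        # negative of the bonus accumulated on the partial match
        state = context_graph.get_next_state(state_id, 0)
        total += state.score
    return total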
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index c44db0206..eb22daefe 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -125,6 +125,7 @@ For example:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -146,6 +147,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import ContextGraph
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -353,6 +355,21 @@ def get_parser():
         Used only when the decoding method is fast_beam_search_nbest,
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="The bonus score added for each token matched in the context graph",
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="Path to a file of context phrases, one phrase per line",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -365,6 +382,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -492,6 +510,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
             return_timestamps=True,
         )
     else:
@@ -556,6 +575,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
     """Decode dataset.
 
@@ -620,6 +640,7 @@ def decode_dataset(
             decoding_graph=decoding_graph,
             word_table=word_table,
             batch=batch,
+            context_graph=context_graph,
         )
 
         for name, (hyps, timestamps_hyp) in hyps_dict.items():
@@ -886,6 +907,18 @@ def main():
         decoding_graph = None
         word_table = None
 
+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            with open(params.context_file) as f:
+                # one context-biasing phrase per line
+                contexts = [line.strip() for line in f]
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build_context_graph(contexts, sp)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
@@ -899,8 +932,11 @@ def main():
     test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
     test_other_dl = librispeech.test_dataloaders(test_other_cuts)
 
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_book_cuts = librispeech.test_book_cuts()
+    test_book_dl = librispeech.test_dataloaders(test_book_cuts)
+
+    test_sets = ["test-book", "test-clean", "test-other"]
+    test_dl = [test_book_dl, test_clean_dl, test_other_dl]
 
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
@@ -910,6 +946,7 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
         )
 
         save_results(
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index c5787835d..ac90daafe 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -445,6 +445,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
         )
 
+    @lru_cache()
+    def test_book_cuts(self) -> CutSet:
+        logging.info("About to get test-book cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libri_books_feats.jsonl.gz"
+        )
+
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")
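main() above reads --context-file as plain text with one biasing phrase per line, encoded with the same BPE model used for decoding. A hedged usage sketch follows; the file name, the phrases, and the data/lang_bpe_500/bpe.model path follow the usual LibriSpeech recipe layout and are assumptions, not part of this patch.

# Build the same ContextGraph that main() builds from --context-file.
import sentencepiece as spm

from icefall import ContextGraph

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # assumed recipe path

# context.txt: one phrase per line, upper-cased like LibriSpeech
# transcripts, e.g.
#   HELLO WORLD
#   LOVE CHINA
with open("context.txt") as f:
    contexts = [line.strip() for line in f if line.strip()]

context_graph = ContextGraph(context_score=2.0)  # mirrors --context-score
context_graph.build_context_graph(contexts, sp)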
diff --git a/icefall/__init__.py b/icefall/__init__.py
index 82d21706c..1dbe4b312 100644
--- a/icefall/__init__.py
+++ b/icefall/__init__.py
@@ -17,6 +17,8 @@ from .checkpoint import (
     save_checkpoint_with_global_batch_idx,
 )
 
+from .context_graph import ContextGraph, ContextState
+
 from .decode import (
     get_lattice,
     nbest_decoding,
diff --git a/icefall/context_graph.py b/icefall/context_graph.py
new file mode 100644
index 000000000..76e7808ad
--- /dev/null
+++ b/icefall/context_graph.py
@@ -0,0 +1,116 @@
+# Copyright    2023  Xiaomi Corp.        (authors: Wei Kang)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import List
+import argparse
+import kaldifst
+import sentencepiece as spm
+
+from icefall.utils import is_module_available
+
+
+@dataclass
+class ContextState:
+    state_id: int = 0
+    score: float = 0.0
+
+
+class ContextGraph:
+    def __init__(self, context_score: float = 1):
+        self.context_score = context_score
+
+    def build_context_graph(self, contexts: List[str], sp: spm.SentencePieceProcessor):
+        contexts_bpe = sp.encode(contexts)
+
+        graph = kaldifst.StdVectorFst()
+        start_state = (
+            graph.add_state()
+        )  # 1st state will be state 0 (returned by add_state)
+        assert start_state == 0, start_state
+        graph.start = 0  # set the start state to 0
+        graph.set_final(start_state, weight=0)  # weight is in log space
+
+        for bpe_ids in contexts_bpe:
+            prev_state = start_state
+            next_state = start_state
+            backoff_score = 0
+            for i in range(len(bpe_ids)):
+                score = self.context_score
+                next_state = graph.add_state() if i < len(bpe_ids) - 1 else start_state
+                graph.add_arc(
+                    state=prev_state,
+                    arc=kaldifst.StdArc(
+                        ilabel=bpe_ids[i],
+                        olabel=bpe_ids[i],
+                        weight=score,
+                        nextstate=next_state,
+                    ),
+                )
+                if i > 0:
+                    graph.add_arc(
+                        state=prev_state,
+                        arc=kaldifst.StdArc(
+                            ilabel=0,
+                            olabel=0,
+                            weight=-backoff_score,
+                            nextstate=start_state,
+                        ),
+                    )
+                prev_state = next_state
+                backoff_score += score
+        self.graph = kaldifst.determinize(graph)
+
+    def get_next_state(self, state_id: int, label: int) -> ContextState:
+        next_state = 0
+        score = 0
+        for arc in kaldifst.ArcIterator(self.graph, state_id):
+            if arc.ilabel == 0:
+                score = arc.weight.value
+            elif arc.ilabel == label:
+                next_state = arc.nextstate
+                score = arc.weight.value
+                break
+        return ContextState(
+            state_id=next_state,
+            score=score,
+        )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--bpe_model",
+        type=str,
+        help="Path to bpe model",
+    )
+    args = parser.parse_args()
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(args.bpe_model)
+
+    contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
+    context_graph = ContextGraph()
+    context_graph.build_context_graph(contexts, sp)
+
+    if not is_module_available("graphviz"):
+        raise ValueError("Please 'pip install graphviz' first.")
+    import graphviz
+
+    fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True)
+    fst_source = graphviz.Source(fst_dot)
+    fst_source.render(outfile="context_graph.svg")
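To check the arithmetic end to end: a full match keeps a bonus of context_score per token, while a partial match is cancelled by the final backoff step, so unfinished phrases gain nothing overall. The sketch below builds a one-phrase graph and walks a full and a truncated token sequence through get_next_state; walk is a hypothetical helper, and the data/lang_bpe_500/bpe.model path is an assumed recipe-standard location, not part of this patch.

import sentencepiece as spm

from icefall import ContextGraph

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # assumed path

graph = ContextGraph(context_score=1.0)
graph.build_context_graph(["HELLO WORLD"], sp)


def walk(tokens):
    # replay the decode-time updates, including the final backoff step
    state_id, total = 0, 0.0
    for token in tokens:
        state = graph.get_next_state(state_id, token)
        state_id, total = state.state_id, total + state.score
    if state_id != 0:
        state = graph.get_next_state(state_id, 0)
        total += state.score
    return total


full = sp.encode("HELLO WORLD")
print(walk(full))       # expected: len(full) * context_score
print(walk(full[:-1]))  # expected: 0.0, the backoff cancels the partial bonus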