Add context biasing for librispeech

pkufool 2023-03-14 19:14:24 +08:00
parent 8aaa9761e4
commit 1f1e28c1ad
5 changed files with 209 additions and 6 deletions

View File

@@ -24,7 +24,7 @@ import sentencepiece as spm
 import torch
 from model import Transducer

-from icefall import NgramLm, NgramLmStateCost
+from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
 from icefall.decode import Nbest, one_best_decoding
 from icefall.lm_wrapper import LmScorer
 from icefall.rnn_lm.model import RnnLmModel
@@ -742,6 +742,9 @@ class Hypothesis:
     # N-gram LM state
     state_cost: Optional[NgramLmStateCost] = None

+    # Context graph state
+    context_state: Optional[ContextState] = None
+
     @property
     def key(self) -> str:
         """Return a string representation of self.ys"""
@@ -883,6 +886,7 @@ def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
+    context_graph: Optional[ContextGraph] = None,
     beam: int = 4,
     temperature: float = 1.0,
     return_timestamps: bool = False,
@@ -934,6 +938,7 @@ def modified_beam_search(
             Hypothesis(
                 ys=[blank_id] * context_size,
                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                context_state=ContextState(state_id=0),
                 timestamp=[],
             )
         )
@@ -1017,17 +1022,53 @@ def modified_beam_search(
                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
                 new_timestamp = hyp.timestamp[:]
+                new_context_state = None
                 if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
                     new_timestamp.append(t)
+                    if context_graph is not None:
+                        new_context_state = context_graph.get_next_state(
+                            hyp.context_state.state_id, new_token
+                        )

-                new_log_prob = topk_log_probs[k]
+                new_log_prob = topk_log_probs[k] + (
+                    0
+                    if new_context_state is None
+                    else new_context_state.score
+                )

                 new_hyp = Hypothesis(
-                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                    ys=new_ys,
+                    log_prob=new_log_prob,
+                    timestamp=new_timestamp,
+                    context_state=hyp.context_state
+                    if new_context_state is None
+                    else new_context_state,
                 )
                 B[i].add(new_hyp)

     B = B + finalized_B

+    # Finalize context_state: if a matched context has not reached a final
+    # state, we need to add the score of the corresponding backoff arc
+    # (which cancels the partially accumulated bonus).
+    if context_graph is not None:
+        finalized_B = [HypothesisList() for _ in range(len(B))]
+        for i, hyps in enumerate(B):
+            for hyp in list(hyps):
+                if hyp.context_state.state_id != 0:
+                    new_context_state = context_graph.get_next_state(
+                        hyp.context_state.state_id, 0
+                    )
+                    finalized_B[i].add(
+                        Hypothesis(
+                            ys=hyp.ys,
+                            log_prob=hyp.log_prob + new_context_state.score,
+                            timestamp=hyp.timestamp,
+                            context_state=new_context_state,
+                        )
+                    )
+                else:
+                    finalized_B[i].add(hyp)
+        B = finalized_B
+
     best_hyps = [b.get_most_probable(length_norm=True) for b in B]
     sorted_ans = [h.ys[context_size:] for h in best_hyps]
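
Aside: the score bookkeeping above can be illustrated with a self-contained toy sketch. This is not the icefall API: it mocks a one-phrase context graph with plain dicts (hypothetical token ids 5 and 7, context_score = 2.0) to show that a completed match keeps its accumulated bonus, while a partial match is cancelled by the backoff arc during finalization.

# Toy mock of the context-graph scoring (not the icefall API).
# One phrase, tokenized as [5, 7]; each matched token earns context_score.
context_score = 2.0
# (state, token) -> (next_state, arc_score); state 0 is start and final.
arcs = {(0, 5): (1, context_score), (1, 7): (0, context_score)}
backoff = {1: -context_score}  # epsilon arc back to the start state

def get_next_state(state: int, token: int):
    # Mirrors ContextGraph.get_next_state: follow a matching arc if one
    # exists, otherwise take the backoff arc (label 0) to the start state.
    if token != 0 and (state, token) in arcs:
        return arcs[(state, token)]
    return 0, backoff.get(state, 0.0)

# Full match [5, 7]: the bonus sticks (+2 per token).
state, bonus = 0, 0.0
for token in [5, 7]:
    state, score = get_next_state(state, token)
    bonus += score
assert (state, bonus) == (0, 4.0)

# Partial match [5], then the finalization pass from the diff above:
state, bonus = 0, 0.0
state, score = get_next_state(state, 5)
bonus += score  # +2.0 so far
if state != 0:
    state, score = get_next_state(state, 0)  # backoff cancels the bonus
    bonus += score  # -2.0
assert (state, bonus) == (0, 0.0)

Without the finalization pass, hypotheses that stall in the middle of a phrase would keep an unearned bonus, which is exactly what the new block after "B = B + finalized_B" prevents.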

View File

@@ -125,6 +125,7 @@ For example:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -146,6 +147,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall import ContextGraph
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -353,6 +355,21 @@ def get_parser():
         Used only when the decoding method is fast_beam_search_nbest,
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )

+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=2,
+        help="The bonus score for each matched token in context words/phrases. "
+        "Used only when --decoding-method is modified_beam_search.",
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="The path of the context words/phrases file, one word/phrase "
+        "per line. Used only when --decoding-method is modified_beam_search.",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -365,6 +382,7 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -492,6 +510,7 @@ def decode_one_batch(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
+            context_graph=context_graph,
             return_timestamps=True,
         )
     else:
@@ -556,6 +575,7 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
     """Decode dataset.
@@ -620,6 +640,7 @@ def decode_dataset(
             decoding_graph=decoding_graph,
             word_table=word_table,
             batch=batch,
+            context_graph=context_graph,
         )

     for name, (hyps, timestamps_hyp) in hyps_dict.items():
@@ -886,6 +907,18 @@ def main():
         decoding_graph = None
         word_table = None

+    if params.decoding_method == "modified_beam_search":
+        if os.path.exists(params.context_file):
+            contexts = []
+            for line in open(params.context_file).readlines():
+                contexts.append(line.strip())
+            context_graph = ContextGraph(params.context_score)
+            context_graph.build_context_graph(contexts, sp)
+        else:
+            context_graph = None
+    else:
+        context_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -899,8 +932,11 @@ def main():
     test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
     test_other_dl = librispeech.test_dataloaders(test_other_cuts)

-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_book_cuts = librispeech.test_book_cuts()
+    test_book_dl = librispeech.test_dataloaders(test_book_cuts)
+
+    test_sets = ["test-book", "test-clean", "test-other"]
+    test_dl = [test_book_dl, test_clean_dl, test_other_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(

@@ -910,6 +946,7 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            context_graph=context_graph,
         )

         save_results(
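
Aside: a minimal usage sketch for the two new options. The --context-file is plain text with one biasing word/phrase per line (main() above simply strips each line); the file name and phrases below are placeholders.

# Write a hypothetical biasing-phrase file for --context-file.
phrases = ["HELLO WORLD", "LOVE CHINA"]
with open("contexts.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(phrases) + "\n")

The file is then passed at decoding time, e.g. --decoding-method modified_beam_search --context-file contexts.txt --context-score 2. Note that biasing is silently skipped when the file does not exist or a different decoding method is selected.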

View File

@@ -445,6 +445,13 @@ class LibriSpeechAsrDataModule:
             self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
         )

+    @lru_cache()
+    def test_book_cuts(self) -> CutSet:
+        logging.info("About to get test-books cuts")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libri_books_feats.jsonl.gz"
+        )
+
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get test-other cuts")

View File

@@ -17,6 +17,8 @@ from .checkpoint import (
     save_checkpoint_with_global_batch_idx,
 )

+from .context_graph import ContextGraph, ContextState
+
 from .decode import (
     get_lattice,
     nbest_decoding,

icefall/context_graph.py Normal file (116 lines added)
View File

@@ -0,0 +1,116 @@
# Copyright      2023  Xiaomi Corp.        (authors: Wei Kang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from dataclasses import dataclass
from typing import List

import kaldifst
import sentencepiece as spm

from icefall.utils import is_module_available


@dataclass
class ContextState:
    # Current state id in the context graph; state 0 is both start and final.
    state_id: int = 0
    # Score of the arc just traversed (a token bonus, or a negative backoff).
    score: float = 0.0


class ContextGraph:
    def __init__(self, context_score: float = 1):
        # Bonus added to a hypothesis's log-prob for each matched token.
        self.context_score = context_score

    def build_context_graph(self, contexts: List[str], sp: spm.SentencePieceProcessor):
        """Build a deterministic FST over the given context words/phrases.

        Each phrase is a chain of arcs leading from the start state back to
        the start state, with ``context_score`` on every arc; each
        intermediate state gets an epsilon backoff arc whose weight cancels
        the bonus accumulated so far.
        """
        contexts_bpe = sp.encode(contexts)
        graph = kaldifst.StdVectorFst()
        start_state = (
            graph.add_state()
        )  # 1st state will be state 0 (returned by add_state)
        assert start_state == 0, start_state
        graph.start = 0  # set the start state to 0
        graph.set_final(start_state, weight=0)  # weight is in log space

        for bpe_ids in contexts_bpe:
            prev_state = start_state
            next_state = start_state
            backoff_score = 0
            for i in range(len(bpe_ids)):
                score = self.context_score
                # The last token of a phrase leads back to the start state.
                next_state = graph.add_state() if i < len(bpe_ids) - 1 else start_state
                graph.add_arc(
                    state=prev_state,
                    arc=kaldifst.StdArc(
                        ilabel=bpe_ids[i],
                        olabel=bpe_ids[i],
                        weight=score,
                        nextstate=next_state,
                    ),
                )
                if i > 0:
                    # Backoff arc: abandoning a partial match returns to the
                    # start state and cancels the accumulated bonus.
                    graph.add_arc(
                        state=prev_state,
                        arc=kaldifst.StdArc(
                            ilabel=0,
                            olabel=0,
                            weight=-backoff_score,
                            nextstate=start_state,
                        ),
                    )
                prev_state = next_state
                backoff_score += score
        self.graph = kaldifst.determinize(graph)

    def get_next_state(self, state_id: int, label: int) -> ContextState:
        """Advance from ``state_id`` on ``label``; if there is no matching
        arc (or label is 0), follow the epsilon backoff arc (if any) back
        to the start state.
        """
        next_state = 0
        score = 0
        for arc in kaldifst.ArcIterator(self.graph, state_id):
            if arc.ilabel == 0:
                score = arc.weight.value
            elif arc.ilabel == label:
                next_state = arc.nextstate
                score = arc.weight.value
                break
        return ContextState(
            state_id=next_state,
            score=score,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bpe_model",
        type=str,
        help="Path to the bpe model",
    )
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
    context_graph = ContextGraph()
    context_graph.build_context_graph(contexts, sp)

    if not is_module_available("graphviz"):
        raise ValueError("Please 'pip install graphviz' first.")
    import graphviz

    fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True)
    fst_source = graphviz.Source(fst_dot)
    fst_source.render(outfile="context_graph.svg")
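
Aside: a small sketch of how the decoder consumes this class, assuming kaldifst is installed and a trained BPE model exists at the placeholder path. Determinization preserves total path weights (though it may redistribute scores across individual arcs), so walking a complete phrase returns to the start state with the full bonus.

import sentencepiece as spm

from icefall import ContextGraph

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # placeholder path

context_graph = ContextGraph(context_score=2.0)
context_graph.build_context_graph(["LOVE CHINA"], sp)

state_id, bonus = 0, 0.0  # state 0 is the start (and final) state
tokens = sp.encode("LOVE CHINA")
for token in tokens:
    ctx = context_graph.get_next_state(state_id, token)
    state_id, bonus = ctx.state_id, bonus + ctx.score

# A complete match ends back at state 0 with bonus == 2.0 * len(tokens).
print(state_id, bonus)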