Add context biasing for LibriSpeech

pkufool 2023-03-14 19:14:24 +08:00
parent 8aaa9761e4
commit 1f1e28c1ad
5 changed files with 209 additions and 6 deletions
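Context biasing boosts user-provided words and phrases at decoding time: the phrases are compiled into a small deterministic FST over BPE tokens (the new ContextGraph), and modified_beam_search adds a per-token bonus to a hypothesis's log-probability whenever it follows an arc of that graph. Partial matches that never reach the end of a phrase are cancelled through backoff arcs, so only fully matched phrases keep their bonus.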

View File

@@ -24,7 +24,7 @@ import sentencepiece as spm
import torch
from model import Transducer
from icefall import NgramLm, NgramLmStateCost
from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.lm_wrapper import LmScorer
from icefall.rnn_lm.model import RnnLmModel
@@ -742,6 +742,9 @@ class Hypothesis:
# N-gram LM state
state_cost: Optional[NgramLmStateCost] = None
# Context graph state
context_state: Optional[ContextState] = None
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
@@ -883,6 +886,7 @@ def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
context_graph: Optional[ContextGraph] = None,
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
@@ -934,6 +938,7 @@ Hypothesis(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
context_state=ContextState(state_id=0),
timestamp=[],
)
)
@@ -1017,17 +1022,53 @@
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
new_context_state = None
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_log_prob = topk_log_probs[k]
if context_graph is not None:
new_context_state = context_graph.get_next_state(
hyp.context_state.state_id, new_token
)
new_log_prob = topk_log_probs[k] + (
0
if new_context_state is None
else new_context_state.score
)
new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
ys=new_ys,
log_prob=new_log_prob,
timestamp=new_timestamp,
context_state=hyp.context_state
if new_context_state is None
else new_context_state,
)
B[i].add(new_hyp)
B = B + finalized_B
# Finalize context_state: if a matched context has not reached the final
# state, add the score of the corresponding backoff arc, which cancels the
# partially accumulated bonus.
if context_graph is not None:
finalized_B = [HypothesisList() for _ in range(len(B))]
for i, hyps in enumerate(B):
for hyp in list(hyps):
if hyp.context_state.state_id != 0:
new_context_state = context_graph.get_next_state(
hyp.context_state.state_id, 0
)
finalized_B[i].add(
Hypothesis(
ys=hyp.ys,
log_prob=hyp.log_prob + new_context_state.score,
timestamp=hyp.timestamp,
context_state=new_context_state,
)
)
else:
finalized_B[i].add(hyp)
B = finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
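To make the new bookkeeping concrete, here is a minimal pure-Python sketch of the same scoring scheme. ToyContextTrie is hypothetical (it matches whole words instead of BPE ids, is not part of icefall, and does not handle phrases that are prefixes of other phrases): every matched token earns a tentative bonus, a completed phrase keeps it, and a mismatch backs off to the start state while cancelling whatever was tentatively added.

from typing import Dict, List, Set, Tuple


class ToyContextTrie:
    """Illustrative stand-in for ContextGraph; words instead of BPE ids."""

    def __init__(self, phrases: List[str], score: float = 2.0):
        self.score = score
        self.next: Dict[Tuple[int, str], int] = {}  # (state, token) -> state
        self.depth: Dict[int, int] = {0: 0}  # tokens matched at each state
        self.finals: Set[int] = set()
        n = 0
        for phrase in phrases:
            state = 0
            for tok in phrase.split():
                if (state, tok) not in self.next:
                    n += 1
                    self.next[(state, tok)] = n
                    self.depth[n] = self.depth[state] + 1
                state = self.next[(state, tok)]
            self.finals.add(state)

    def advance(self, state: int, tok: str) -> Tuple[int, float]:
        nxt = self.next.get((state, tok))
        if nxt is not None:
            # Tentative per-token bonus; a finished phrase restarts matching.
            return (0 if nxt in self.finals else nxt), self.score
        # Mismatch: back off to the start and cancel the tentative bonus.
        return 0, -self.score * self.depth[state]


trie = ToyContextTrie(["LOVE CHINA"])
state, total = 0, 0.0
for tok in ["I", "LOVE", "WORLD"]:
    state, delta = trie.advance(state, tok)
    total += delta
print(total)  # 0.0 -- the bonus for "LOVE" is cancelled at the mismatch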

View File

@@ -125,6 +125,7 @@ For example:
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@@ -146,6 +147,7 @@ from beam_search import (
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import ContextGraph
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@@ -353,6 +355,21 @@ def get_parser():
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
    "--context-score",
    type=float,
    default=2,
    help="""The bonus score for each token matched in the biasing contexts.
    Used only when --context-file is provided.""",
)
parser.add_argument(
    "--context-file",
    type=str,
    default="",
    help="""Path to a file containing the biasing words/phrases, one per
    line. If empty, no context biasing is applied.""",
)
add_model_arguments(parser)
return parser
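The two new flags work together: --context-file points to a plain-text list of biasing words/phrases, one per line (upper-cased so that they match LibriSpeech transcripts), and --context-score is the bonus added for every BPE token of a phrase that a hypothesis matches. A hypothetical context file could look like:

LOVE CHINA
HELLO WORLD
LOVE WORLD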
@@ -365,6 +382,7 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@@ -492,6 +510,7 @@
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
return_timestamps=True,
)
else:
@@ -556,6 +575,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
"""Decode dataset.
@@ -620,6 +640,7 @@
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
context_graph=context_graph,
)
for name, (hyps, timestamps_hyp) in hyps_dict.items():
@@ -886,6 +907,18 @@ def main():
decoding_graph = None
word_table = None
if params.decoding_method == "modified_beam_search":
if os.path.exists(params.context_file):
contexts = []
with open(params.context_file) as f:
    for line in f:
        contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build_context_graph(contexts, sp)
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@@ -899,8 +932,11 @@
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
test_book_cuts = librispeech.test_book_cuts()
test_book_dl = librispeech.test_dataloaders(test_book_cuts)
test_sets = ["test-book", "test-clean", "test-other"]
test_dl = [test_book_dl, test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
@@ -910,6 +946,7 @@
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
context_graph=context_graph,
)
save_results(

View File

@@ -445,6 +445,13 @@ class LibriSpeechAsrDataModule:
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
)
@lru_cache()
def test_book_cuts(self) -> CutSet:
logging.info("About to get test-books cuts")
return load_manifest_lazy(
self.args.manifest_dir / "libri_books_feats.jsonl.gz"
)
@lru_cache()
def test_other_cuts(self) -> CutSet:
logging.info("About to get test-other cuts")

View File

@@ -17,6 +17,8 @@ from .checkpoint import (
save_checkpoint_with_global_batch_idx,
)
from .context_graph import ContextGraph, ContextState
from .decode import (
get_lattice,
nbest_decoding,

icefall/context_graph.py Normal file
View File

@@ -0,0 +1,116 @@
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from dataclasses import dataclass
from typing import List

import kaldifst
import sentencepiece as spm

from icefall.utils import is_module_available


@dataclass
class ContextState:
state_id: int = 0
score: float = 0.0


class ContextGraph:
def __init__(self, context_score: float = 1):
self.context_score = context_score

def build_context_graph(self, contexts: List[str], sp: spm.SentencePieceProcessor):
    """Build a deterministic FST over BPE tokens from the given context phrases.

    Each phrase becomes a chain of arcs from the start state, every arc
    carrying ``context_score``; each intermediate state also gets a backoff
    (epsilon) arc back to the start state whose negative weight cancels the
    score accumulated so far.
    """
contexts_bpe = sp.encode(contexts)
graph = kaldifst.StdVectorFst()
# The first call to add_state() returns state 0.
start_state = graph.add_state()
assert start_state == 0, start_state
graph.start = 0 # set the start state to 0
graph.set_final(start_state, weight=0) # weight is in log space
for bpe_ids in contexts_bpe:
prev_state = start_state
next_state = start_state
backoff_score = 0
for i in range(len(bpe_ids)):
score = self.context_score
next_state = graph.add_state() if i < len(bpe_ids) - 1 else start_state
graph.add_arc(
state=prev_state,
arc=kaldifst.StdArc(
ilabel=bpe_ids[i],
olabel=bpe_ids[i],
weight=score,
nextstate=next_state,
),
)
if i > 0:
graph.add_arc(
state=prev_state,
arc=kaldifst.StdArc(
ilabel=0,
olabel=0,
weight=-backoff_score,
nextstate=start_state,
),
)
prev_state = next_state
backoff_score += score
self.graph = kaldifst.determinize(graph)

def get_next_state(self, state_id: int, label: int) -> ContextState:
    """Advance the graph from ``state_id`` on ``label``.

    A label of 0 follows the backoff (epsilon) arc. If ``label`` matches no
    arc, fall back to the start state with the backoff score of ``state_id``
    (0 if the state has no backoff arc).
    """
next_state = 0
score = 0
for arc in kaldifst.ArcIterator(self.graph, state_id):
if arc.ilabel == 0:
score = arc.weight.value
elif arc.ilabel == label:
next_state = arc.nextstate
score = arc.weight.value
break
return ContextState(
state_id=next_state,
score=score,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe_model",
type=str,
help="Path to bpe model",
)
args = parser.parse_args()
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
context_graph = ContextGraph()
context_graph.build_context_graph(contexts, sp)
if not is_module_available("graphviz"):
raise ValueError("Please 'pip install graphviz' first.")
import graphviz
fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True)
fst_source = graphviz.Source(fst_dot)
fst_source.render(outfile="context_graph.svg")
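
As a hedged continuation of the demo above, reusing sp and context_graph from the __main__ block (and assuming "LOVE CHINA" encodes to more than one BPE piece), the graph can be stepped the same way modified_beam_search does. Note that kaldifst.determinize may redistribute weights along a path, so the printed numbers are illustrative:

ids = sp.encode("LOVE CHINA")  # BPE ids of one biased phrase
state = ContextState(state_id=0)
total = 0.0
for tok in ids:
    state = context_graph.get_next_state(state.state_id, tok)
    total += state.score
print(total)  # roughly context_score * len(ids) for a full match

# A hypothesis that stops mid-phrase is finalized with the backoff
# (epsilon) arc, whose negative weight cancels the tentative bonus:
state = context_graph.get_next_state(0, ids[0])
backoff = context_graph.get_next_state(state.state_id, 0)
print(state.score + backoff.score)  # close to 0 after the backoff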