mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add context biasing for librispeech
This commit is contained in:
parent
8aaa9761e4
commit
1f1e28c1ad
@ -24,7 +24,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
from model import Transducer
|
||||
|
||||
from icefall import NgramLm, NgramLmStateCost
|
||||
from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
from icefall.lm_wrapper import LmScorer
|
||||
from icefall.rnn_lm.model import RnnLmModel
|
||||
@ -742,6 +742,9 @@ class Hypothesis:
|
||||
# N-gram LM state
|
||||
state_cost: Optional[NgramLmStateCost] = None
|
||||
|
||||
# Context graph state
|
||||
context_state: Optional[ContextState] = None
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Return a string representation of self.ys"""
|
||||
@ -883,6 +886,7 @@ def modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
@ -934,6 +938,7 @@ def modified_beam_search(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
context_state=ContextState(state_id=0),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
@ -1017,17 +1022,53 @@ def modified_beam_search(
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
new_context_state = None
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
if context_graph is not None:
|
||||
new_context_state = context_graph.get_next_state(
|
||||
hyp.context_state.state_id, new_token
|
||||
)
|
||||
new_log_prob = topk_log_probs[k] + (
|
||||
0
|
||||
if new_context_state is None
|
||||
else new_context_state.score
|
||||
)
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
ys=new_ys,
|
||||
log_prob=new_log_prob,
|
||||
timestamp=new_timestamp,
|
||||
context_state=hyp.context_state
|
||||
if new_context_state is None
|
||||
else new_context_state,
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
B = B + finalized_B
|
||||
|
||||
# finalize context_state, if the matched contexts do not reach final state
|
||||
# we need to add the score on the corresponding backoff arc
|
||||
if context_graph is not None:
|
||||
finalized_B = [HypothesisList() for _ in range(len(B))]
|
||||
for i, hyps in enumerate(B):
|
||||
for hyp in list(hyps):
|
||||
if hyp.context_state.state_id != 0:
|
||||
new_context_state = context_graph.get_next_state(
|
||||
hyp.context_state.state_id, 0
|
||||
)
|
||||
finalized_B[i].add(
|
||||
Hypothesis(
|
||||
ys=hyp.ys,
|
||||
log_prob=hyp.log_prob + new_context_state.score,
|
||||
timestamp=hyp.timestamp,
|
||||
context_state=new_context_state,
|
||||
)
|
||||
)
|
||||
else:
|
||||
finalized_B[i].add(hyp)
|
||||
B = finalized_B
|
||||
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
|
||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||
|
||||
@ -125,6 +125,7 @@ For example:
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -146,6 +147,7 @@ from beam_search import (
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import ContextGraph
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -353,6 +355,21 @@ def get_parser():
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-score",
|
||||
type=float,
|
||||
default=2,
|
||||
help="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -365,6 +382,7 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
@ -492,6 +510,7 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
context_graph=context_graph,
|
||||
return_timestamps=True,
|
||||
)
|
||||
else:
|
||||
@ -556,6 +575,7 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
@ -620,6 +640,7 @@ def decode_dataset(
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
|
||||
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||
@ -886,6 +907,18 @@ def main():
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
if params.decoding_method == "modified_beam_search":
|
||||
if os.path.exists(params.context_file):
|
||||
contexts = []
|
||||
for line in open(params.context_file).readlines():
|
||||
contexts.append(line.strip())
|
||||
context_graph = ContextGraph(params.context_score)
|
||||
context_graph.build_context_graph(contexts, sp)
|
||||
else:
|
||||
context_graph = None
|
||||
else:
|
||||
context_graph = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
@ -899,8 +932,11 @@ def main():
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
test_book_cuts = librispeech.test_book_cuts()
|
||||
test_book_dl = librispeech.test_dataloaders(test_book_cuts)
|
||||
|
||||
test_sets = ["test-book", "test-clean", "test-other"]
|
||||
test_dl = [test_book_dl, test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
@ -910,6 +946,7 @@ def main():
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
|
||||
@ -445,6 +445,13 @@ class LibriSpeechAsrDataModule:
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_book_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-books cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "libri_books_feats.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
|
||||
@ -17,6 +17,8 @@ from .checkpoint import (
|
||||
save_checkpoint_with_global_batch_idx,
|
||||
)
|
||||
|
||||
from .context_graph import ContextGraph, ContextState
|
||||
|
||||
from .decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
|
||||
116
icefall/context_graph.py
Normal file
116
icefall/context_graph.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
import argparse
|
||||
import kaldifst
|
||||
import sentencepiece as spm
|
||||
|
||||
from icefall.utils import is_module_available
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextState:
|
||||
state_id: int = 0
|
||||
score: float = 0.0
|
||||
|
||||
|
||||
class ContextGraph:
|
||||
def __init__(self, context_score: float = 1):
|
||||
self.context_score = context_score
|
||||
|
||||
def build_context_graph(self, contexts: List[str], sp: spm.SentencePieceProcessor):
|
||||
|
||||
contexts_bpe = sp.encode(contexts)
|
||||
graph = kaldifst.StdVectorFst()
|
||||
start_state = (
|
||||
graph.add_state()
|
||||
) # 1st state will be state 0 (returned by add_state)
|
||||
assert start_state == 0, start_state
|
||||
graph.start = 0 # set the start state to 0
|
||||
graph.set_final(start_state, weight=0) # weight is in log space
|
||||
|
||||
for bpe_ids in contexts_bpe:
|
||||
prev_state = start_state
|
||||
next_state = start_state
|
||||
backoff_score = 0
|
||||
for i in range(len(bpe_ids)):
|
||||
score = self.context_score
|
||||
next_state = graph.add_state() if i < len(bpe_ids) - 1 else start_state
|
||||
graph.add_arc(
|
||||
state=prev_state,
|
||||
arc=kaldifst.StdArc(
|
||||
ilabel=bpe_ids[i],
|
||||
olabel=bpe_ids[i],
|
||||
weight=score,
|
||||
nextstate=next_state,
|
||||
),
|
||||
)
|
||||
if i > 0:
|
||||
graph.add_arc(
|
||||
state=prev_state,
|
||||
arc=kaldifst.StdArc(
|
||||
ilabel=0,
|
||||
olabel=0,
|
||||
weight=-backoff_score,
|
||||
nextstate=start_state,
|
||||
),
|
||||
)
|
||||
prev_state = next_state
|
||||
backoff_score += score
|
||||
self.graph = kaldifst.determinize(graph)
|
||||
|
||||
def get_next_state(self, state_id: int, label: int) -> ContextState:
|
||||
next_state = 0
|
||||
score = 0
|
||||
for arc in kaldifst.ArcIterator(self.graph, state_id):
|
||||
if arc.ilabel == 0:
|
||||
score = arc.weight.value
|
||||
elif arc.ilabel == label:
|
||||
next_state = arc.nextstate
|
||||
score = arc.weight.value
|
||||
break
|
||||
return ContextState(
|
||||
state_id=next_state,
|
||||
score=score,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bpe_model",
|
||||
type=str,
|
||||
help="Path to bpe model",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
|
||||
context_graph = ContextGraph()
|
||||
context_graph.build_context_graph(contexts, sp)
|
||||
|
||||
if not is_module_available("graphviz"):
|
||||
raise ValueError("Please 'pip install graphviz' first.")
|
||||
import graphviz
|
||||
|
||||
fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True)
|
||||
fst_source = graphviz.Source(fst_dot)
|
||||
fst_source.render(outfile="context_graph.svg")
|
||||
Loading…
x
Reference in New Issue
Block a user