mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add context biasing for librispeech
This commit is contained in:
parent
8aaa9761e4
commit
1f1e28c1ad
@ -24,7 +24,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
from icefall import NgramLm, NgramLmStateCost
|
from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
|
||||||
from icefall.decode import Nbest, one_best_decoding
|
from icefall.decode import Nbest, one_best_decoding
|
||||||
from icefall.lm_wrapper import LmScorer
|
from icefall.lm_wrapper import LmScorer
|
||||||
from icefall.rnn_lm.model import RnnLmModel
|
from icefall.rnn_lm.model import RnnLmModel
|
||||||
@ -742,6 +742,9 @@ class Hypothesis:
|
|||||||
# N-gram LM state
|
# N-gram LM state
|
||||||
state_cost: Optional[NgramLmStateCost] = None
|
state_cost: Optional[NgramLmStateCost] = None
|
||||||
|
|
||||||
|
# Context graph state
|
||||||
|
context_state: Optional[ContextState] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self) -> str:
|
def key(self) -> str:
|
||||||
"""Return a string representation of self.ys"""
|
"""Return a string representation of self.ys"""
|
||||||
@ -883,6 +886,7 @@ def modified_beam_search(
|
|||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
return_timestamps: bool = False,
|
return_timestamps: bool = False,
|
||||||
@ -934,6 +938,7 @@ def modified_beam_search(
|
|||||||
Hypothesis(
|
Hypothesis(
|
||||||
ys=[blank_id] * context_size,
|
ys=[blank_id] * context_size,
|
||||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
context_state=ContextState(state_id=0),
|
||||||
timestamp=[],
|
timestamp=[],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -1017,17 +1022,53 @@ def modified_beam_search(
|
|||||||
new_ys = hyp.ys[:]
|
new_ys = hyp.ys[:]
|
||||||
new_token = topk_token_indexes[k]
|
new_token = topk_token_indexes[k]
|
||||||
new_timestamp = hyp.timestamp[:]
|
new_timestamp = hyp.timestamp[:]
|
||||||
|
new_context_state = None
|
||||||
if new_token not in (blank_id, unk_id):
|
if new_token not in (blank_id, unk_id):
|
||||||
new_ys.append(new_token)
|
new_ys.append(new_token)
|
||||||
new_timestamp.append(t)
|
new_timestamp.append(t)
|
||||||
|
if context_graph is not None:
|
||||||
new_log_prob = topk_log_probs[k]
|
new_context_state = context_graph.get_next_state(
|
||||||
|
hyp.context_state.state_id, new_token
|
||||||
|
)
|
||||||
|
new_log_prob = topk_log_probs[k] + (
|
||||||
|
0
|
||||||
|
if new_context_state is None
|
||||||
|
else new_context_state.score
|
||||||
|
)
|
||||||
new_hyp = Hypothesis(
|
new_hyp = Hypothesis(
|
||||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
ys=new_ys,
|
||||||
|
log_prob=new_log_prob,
|
||||||
|
timestamp=new_timestamp,
|
||||||
|
context_state=hyp.context_state
|
||||||
|
if new_context_state is None
|
||||||
|
else new_context_state,
|
||||||
)
|
)
|
||||||
B[i].add(new_hyp)
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
B = B + finalized_B
|
B = B + finalized_B
|
||||||
|
|
||||||
|
# finalize context_state, if the matched contexts do not reach final state
|
||||||
|
# we need to add the score on the corresponding backoff arc
|
||||||
|
if context_graph is not None:
|
||||||
|
finalized_B = [HypothesisList() for _ in range(len(B))]
|
||||||
|
for i, hyps in enumerate(B):
|
||||||
|
for hyp in list(hyps):
|
||||||
|
if hyp.context_state.state_id != 0:
|
||||||
|
new_context_state = context_graph.get_next_state(
|
||||||
|
hyp.context_state.state_id, 0
|
||||||
|
)
|
||||||
|
finalized_B[i].add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=hyp.ys,
|
||||||
|
log_prob=hyp.log_prob + new_context_state.score,
|
||||||
|
timestamp=hyp.timestamp,
|
||||||
|
context_state=new_context_state,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
finalized_B[i].add(hyp)
|
||||||
|
B = finalized_B
|
||||||
|
|
||||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||||
|
|
||||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||||
|
|||||||
@ -125,6 +125,7 @@ For example:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
@ -146,6 +147,7 @@ from beam_search import (
|
|||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall import ContextGraph
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
@ -353,6 +355,21 @@ def get_parser():
|
|||||||
Used only when the decoding method is fast_beam_search_nbest,
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-score",
|
||||||
|
type=float,
|
||||||
|
default=2,
|
||||||
|
help="",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-file",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -365,6 +382,7 @@ def decode_one_batch(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -492,6 +510,7 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
|
context_graph=context_graph,
|
||||||
return_timestamps=True,
|
return_timestamps=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -556,6 +575,7 @@ def decode_dataset(
|
|||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
context_graph: Optional[ContextGraph] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -620,6 +640,7 @@ def decode_dataset(
|
|||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
context_graph=context_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
for name, (hyps, timestamps_hyp) in hyps_dict.items():
|
||||||
@ -886,6 +907,18 @@ def main():
|
|||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
word_table = None
|
word_table = None
|
||||||
|
|
||||||
|
if params.decoding_method == "modified_beam_search":
|
||||||
|
if os.path.exists(params.context_file):
|
||||||
|
contexts = []
|
||||||
|
for line in open(params.context_file).readlines():
|
||||||
|
contexts.append(line.strip())
|
||||||
|
context_graph = ContextGraph(params.context_score)
|
||||||
|
context_graph.build_context_graph(contexts, sp)
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
else:
|
||||||
|
context_graph = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -899,8 +932,11 @@ def main():
|
|||||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||||
|
|
||||||
test_sets = ["test-clean", "test-other"]
|
test_book_cuts = librispeech.test_book_cuts()
|
||||||
test_dl = [test_clean_dl, test_other_dl]
|
test_book_dl = librispeech.test_dataloaders(test_book_cuts)
|
||||||
|
|
||||||
|
test_sets = ["test-book", "test-clean", "test-other"]
|
||||||
|
test_dl = [test_book_dl, test_clean_dl, test_other_dl]
|
||||||
|
|
||||||
for test_set, test_dl in zip(test_sets, test_dl):
|
for test_set, test_dl in zip(test_sets, test_dl):
|
||||||
results_dict = decode_dataset(
|
results_dict = decode_dataset(
|
||||||
@ -910,6 +946,7 @@ def main():
|
|||||||
sp=sp,
|
sp=sp,
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
context_graph=context_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
|||||||
@ -445,6 +445,13 @@ class LibriSpeechAsrDataModule:
|
|||||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@lru_cache()
def test_book_cuts(self) -> CutSet:
    """Return the cuts stored in ``libri_books_feats.jsonl.gz``.

    The manifest is looked up under ``self.args.manifest_dir`` and opened
    with ``load_manifest_lazy``.  The result is memoized by ``lru_cache``.
    """
    logging.info("About to get test-books cuts")
    manifest_path = self.args.manifest_dir / "libri_books_feats.jsonl.gz"
    return load_manifest_lazy(manifest_path)
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_other_cuts(self) -> CutSet:
|
def test_other_cuts(self) -> CutSet:
|
||||||
logging.info("About to get test-other cuts")
|
logging.info("About to get test-other cuts")
|
||||||
|
|||||||
@ -17,6 +17,8 @@ from .checkpoint import (
|
|||||||
save_checkpoint_with_global_batch_idx,
|
save_checkpoint_with_global_batch_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .context_graph import ContextGraph, ContextState
|
||||||
|
|
||||||
from .decode import (
|
from .decode import (
|
||||||
get_lattice,
|
get_lattice,
|
||||||
nbest_decoding,
|
nbest_decoding,
|
||||||
|
|||||||
116
icefall/context_graph.py
Normal file
116
icefall/context_graph.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Wei Kang)
|
||||||
|
#
|
||||||
|
# See ../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
import argparse
|
||||||
|
import kaldifst
|
||||||
|
import sentencepiece as spm
|
||||||
|
|
||||||
|
from icefall.utils import is_module_available
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ContextState:
    """A state of the context graph together with a score.

    ``ContextGraph.get_next_state`` returns instances of this class: the
    state reached after consuming a label and the weight of the arc that
    was followed.
    """

    # Id of the state in the context graph (0 by default).
    state_id: int = 0

    # Score attached to reaching this state (0.0 by default).
    score: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class ContextGraph:
    """FST over BPE token ids used to add a score bonus to context phrases.

    Usage: create the object, call :meth:`build_context_graph` once with the
    phrases to boost, then call :meth:`get_next_state` for every emitted
    token to obtain the next graph state and the score of the arc taken.
    """

    def __init__(self, context_score: float = 1):
        """
        Args:
          context_score:
            The weight put on every token arc of a context phrase.
        """
        self.context_score = context_score
        # Filled in by build_context_graph().  Kept as None until then so
        # that querying the graph too early fails with a clear message.
        self.graph = None

    def build_context_graph(
        self, contexts: List[str], sp: spm.SentencePieceProcessor
    ) -> None:
        """Build the graph from ``contexts`` and store it in ``self.graph``.

        Each phrase is encoded with ``sp`` and added as a chain of arcs that
        starts at state 0 and whose last arc returns to state 0.  Every token
        arc carries ``self.context_score``.  Every state inside a chain also
        gets an arc with label 0 back to state 0 whose weight is minus the
        score accumulated so far on that chain, so that the bonus of a
        partial match can be taken back.  The result is determinized.

        Args:
          contexts:
            The phrases to boost, as plain text.
          sp:
            The sentencepiece processor used to encode the phrases.
        """
        contexts_bpe = sp.encode(contexts)
        graph = kaldifst.StdVectorFst()
        # The first state created is state 0 (returned by add_state).
        start_state = graph.add_state()
        assert start_state == 0, start_state
        graph.start = 0  # set the start state to 0
        graph.set_final(start_state, weight=0)  # weight is in log space

        score = self.context_score
        for bpe_ids in contexts_bpe:
            prev_state = start_state
            backoff_score = 0
            last = len(bpe_ids) - 1
            for i, token in enumerate(bpe_ids):
                # The last token of a phrase leads back to the start state.
                next_state = graph.add_state() if i < last else start_state
                graph.add_arc(
                    state=prev_state,
                    arc=kaldifst.StdArc(
                        ilabel=token,
                        olabel=token,
                        weight=score,
                        nextstate=next_state,
                    ),
                )
                if i > 0:
                    # Backoff arc: cancels the score collected on this chain
                    # before reaching prev_state.
                    graph.add_arc(
                        state=prev_state,
                        arc=kaldifst.StdArc(
                            ilabel=0,
                            olabel=0,
                            weight=-backoff_score,
                            nextstate=start_state,
                        ),
                    )
                prev_state = next_state
                backoff_score += score
        self.graph = kaldifst.determinize(graph)

    def get_next_state(self, state_id: int, label: int) -> ContextState:
        """Return the state reached from ``state_id`` on ``label``.

        If ``state_id`` has an arc whose input label equals ``label``, the
        destination and weight of that arc are returned.  Otherwise the
        result is state 0 with the weight of the label-0 (backoff) arc of
        ``state_id`` if one was seen, else with score 0.

        NOTE(review): after backing off, ``label`` is not looked up again
        from state 0, so a phrase that begins right where a partial match
        failed does not receive its first-token score.

        Args:
          state_id:
            The current state in the context graph.
          label:
            The token id to consume; pass 0 to take the backoff arc.
        Returns:
          A ContextState holding the next state id and the score to add.
        Raises:
          RuntimeError: If build_context_graph() has not been called yet.
        """
        if self.graph is None:
            raise RuntimeError(
                "The context graph has not been built yet; "
                "call build_context_graph() first."
            )
        next_state = 0
        score = 0
        for arc in kaldifst.ArcIterator(self.graph, state_id):
            if arc.ilabel == 0:
                # Remember the backoff score; a matching label arc found
                # later in the iteration overrides it.
                score = arc.weight.value
            elif arc.ilabel == label:
                next_state = arc.nextstate
                score = arc.weight.value
                break
        return ContextState(
            state_id=next_state,
            score=score,
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Small demo: build a context graph for a few hard-coded phrases and
    # render it to context_graph.svg with graphviz.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bpe_model",
        type=str,
        # Without a model path sp.load() below would receive None, so let
        # argparse reject the command line up front.
        required=True,
        help="Path to bpe model",
    )
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    contexts = ["LOVE CHINA", "HELLO WORLD", "LOVE WORLD"]
    context_graph = ContextGraph()
    context_graph.build_context_graph(contexts, sp)

    # graphviz is only needed for this demo, so it is imported here rather
    # than at the top of the module.
    if not is_module_available("graphviz"):
        raise ValueError("Please 'pip install graphviz' first.")
    import graphviz

    fst_dot = kaldifst.draw(context_graph.graph, acceptor=False, portrait=True)
    fst_source = graphviz.Source(fst_dot)
    fst_source.render(outfile="context_graph.svg")
|
||||||
Loading…
x
Reference in New Issue
Block a user