WIP: Use shallow fusion in modified beam search.

This commit is contained in:
Fangjun Kuang 2022-02-08 20:40:45 +08:00
parent 27fa5f05d3
commit 954b4efff3
5 changed files with 417 additions and 4 deletions

@@ -0,0 +1,160 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates LG from
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_LG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_500.
Return:
An FST representing LG.
"""
tokens = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
assert "#0" in tokens
first_token_disambig_id = tokens["#0"]
logging.info(f"first token disambig ID: {first_token_disambig_id}")
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
logging.info("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
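            # G's aux_labels are identical to its labels here, and they are
            # not needed in the computations below, so delete them to save
            # memory.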
del G.aux_labels
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Composing L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}, num_arcs: {LG.num_arcs}")
del LG.aux_labels
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(
f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
)
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(
f"LG shape after k2.determinize: {LG.shape}, num_arcs: {LG.num_arcs}"
)
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info(
f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
)
logging.info("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set LG.properties to None
LG.__dict__["_properties"] = None
logging.info("Removing epsilons")
LG = k2.remove_epsilon(LG)
logging.info(
f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
)
logging.info("Connecting")
LG = k2.connect(LG)
logging.info(
f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info(f"LG properties: {LG.properties_str}")
    # The possible properties are:
# "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible"
logging.info("Caution: LG is not deterministic!!!")
return LG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
out_filename = lang_dir / "LG.pt"
if out_filename.is_file():
logging.info(f"{out_filename} already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
LG = compile_LG(lang_dir)
logging.info(f"Saving LG to {out_filename}")
torch.save(LG.as_dict(), out_filename)
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
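For reference, a typical invocation of this script (assuming the standard
icefall layout, with data/lm/G_3_gram.fst.txt already prepared) would be:

    cd egs/librispeech/ASR
    python ./local/compile_lg.py --lang-dir ./data/lang_bpe_500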

@@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./local/test_compile_lg.py
"""
from pathlib import Path
from typing import List
import k2
import sentencepiece as spm
import torch
lang_dir = Path("./data/lang_bpe_500")
def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]:
"""
Args:
word_table:
Word symbol table.
s:
        A string of space-separated words.
Returns:
Return a list of word IDs.
"""
ans = []
for w in s.split():
ans.append(word_table[w])
return ans
def main():
assert lang_dir.exists(), f"{lang_dir} does not exist!"
LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu"))
sp = spm.SentencePieceProcessor()
sp.load(f"{lang_dir}/bpe.model")
word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
s = "HELLO WORLD"
token_ids = sp.encode(s)
token_fsa = k2.linear_fsa(token_ids)
fsa = k2.intersect(LG, token_fsa)
fsa = k2.connect(fsa)
print(k2.to_dot(fsa))
print(fsa.properties_str)
print(LG.properties_str)
# You can use https://dreampuf.github.io/GraphvizOnline/
# to visualize the output.
#
    # You can see that the resulting FSA is not deterministic,
    # since LG itself is non-deterministic.
#
# See https://shorturl.at/uIL69
# for visualization of the above fsa.
if __name__ == "__main__":
main()

@@ -17,8 +17,10 @@
from dataclasses import dataclass
from typing import Dict, List, Optional
import k2
import torch
from model import Transducer
from shallow_fusion import shallow_fusion
def greedy_search(
@@ -111,6 +113,13 @@ class Hypothesis:
    # It contains only one entry.
    log_prob: torch.Tensor
    # Used for shallow fusion.
    # The key of the dict is a state index into LG,
    # while the corresponding value is the LM score for
    # reaching this state from the start state.
    # Note: The value tensor contains only a single entry.
    ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None
    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
@@ -149,6 +158,15 @@ class HypothesisList(object):
            torch.logaddexp(
                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
            )
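            # Merge the n-gram LM states of the two equal-key hypotheses,
            # keeping the better (higher) LM score for each LG state.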
            if hyp.ngram_state_and_scores is not None:
                for state, score in hyp.ngram_state_and_scores.items():
                    if (
                        state not in old_hyp.ngram_state_and_scores
                        or score > old_hyp.ngram_state_and_scores[state]
                    ):
                        old_hyp.ngram_state_and_scores[state] = score
        else:
            self._data[key] = hyp
@@ -318,6 +336,7 @@ def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
    LG: Optional[k2.Fsa] = None,
) -> List[int]:
    """It limits the maximum number of symbols per frame to 1.
@@ -328,9 +347,13 @@
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
      LG:
        Optional. Used for shallow fusion.
    Returns:
      Return the decoded result.
    """
    enable_shallow_fusion = LG is not None
    ngram_lm_scale = 0.8
    assert encoder_out.ndim == 3
@@ -350,10 +373,19 @@
    T = encoder_out.size(1)
    B = HypothesisList()
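    # Each hypothesis tracks the set of LG states it may be in, together
    # with the best LM score for reaching each state. All hypotheses start
    # from LG's start state (state 0 in k2).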
    if enable_shallow_fusion:
        ngram_state_and_scores = {
            0: torch.zeros(1, dtype=torch.float32, device=device)
        }
    else:
        ngram_state_and_scores = None
    B.add(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
            ngram_state_and_scores=ngram_state_and_scores,
        )
    )
@@ -411,9 +443,33 @@
            new_token = topk_token_indexes[i]
            if new_token != blank_id:
                new_ys.append(new_token)
            else:
                ngram_state_and_scores = hyp.ngram_state_and_scores
            new_log_prob = topk_log_probs[i]
            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
            if enable_shallow_fusion and new_token != blank_id:
                ngram_state_and_scores = shallow_fusion(
                    LG, new_token, hyp.ngram_state_and_scores
                )
                if len(ngram_state_and_scores) == 0:
                    continue
                max_ngram_score = max(ngram_state_and_scores.values())
                new_log_prob += ngram_lm_scale * max_ngram_score
            new_hyp = Hypothesis(
                ys=new_ys,
                log_prob=new_log_prob,
                ngram_state_and_scores=ngram_state_and_scores,
            )
            B.add(new_hyp)
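        # Shallow fusion may prune away every expanded hypothesis for this
        # frame (when no out-going arc in LG matches the new token). Fall
        # back to the previous frame's hypotheses so decoding can continue.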
        if len(B) == 0:
            for h in A:
                B.add(h)
    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
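
To make the scoring concrete, here is a tiny sketch of how the fused
log-prob of a non-blank token is formed (all numbers and the LG state IDs
35 and 78 are made up for illustration): the best LM score over all
reachable LG states is scaled by ngram_lm_scale and added to the
transducer's own log-prob.

    import torch

    ngram_lm_scale = 0.8
    new_log_prob = torch.tensor(-1.2)  # transducer log-prob (hypothetical)
    ngram_state_and_scores = {  # LG states reachable via the new token
        35: torch.tensor(-2.0),
        78: torch.tensor(-0.5),
    }
    max_ngram_score = max(ngram_state_and_scores.values())
    new_log_prob = new_log_prob + ngram_lm_scale * max_ngram_score
    print(new_log_prob)  # tensor(-1.6000), i.e. -1.2 + 0.8 * (-0.5)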

@@ -40,8 +40,9 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
@@ -131,6 +132,13 @@ def get_parser():
        Used only when --decoding_method is greedy_search""",
    )
    parser.add_argument(
        "--LG",
        type=str,
        help="""Path to LG.pt for shallow fusion.
        Used only when --decoding-method is modified_beam_search.""",
    )
    return parser
@@ -203,6 +211,7 @@ def decode_one_batch(
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    LG: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -225,6 +234,9 @@
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      LG:
        Optional. Used for shallow fusion. Used only when
        params.decoding_method is modified_beam_search.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
@@ -257,17 +269,24 @@
            )
        elif params.decoding_method == "beam_search":
            hyp = beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
                model=model,
                encoder_out=encoder_out_i,
                beam=params.beam_size,
            )
        elif params.decoding_method == "modified_beam_search":
            hyp = modified_beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
                model=model,
                encoder_out=encoder_out_i,
                beam=params.beam_size,
                LG=LG,
            )
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
            )
        hyps.append(sp.decode(hyp).split())
    for h in hyps:
        print(" ".join(h))
    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
@@ -280,6 +299,7 @@ def decode_dataset(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    LG: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.
@@ -292,6 +312,9 @@
        The neural model.
      sp:
        The BPE model.
      LG:
        Optional. Used for shallow fusion. Used only when
        params.decoding_method is modified_beam_search.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
@@ -320,6 +343,7 @@
            model=model,
            sp=sp,
            batch=batch,
            LG=LG,
        )
        for name, hyps in hyps_dict.items():
@@ -419,6 +443,21 @@ def main():
    logging.info(f"Device: {device}")
    if params.LG is not None:
        assert (
            params.decoding_method == "modified_beam_search"
        ), "--LG is used only when --decoding_method=modified_beam_search"
        logging.info(f"Loading LG from {params.LG}")
        LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device))
        logging.info(f"LG properties: {LG.properties_str}")
        logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
        # If LG is created by local/compile_lg.py, then it should be
        # epsilon free as well as arc sorted.
        assert "ArcSorted" in LG.properties_str
        assert "EpsilonFree" in LG.properties_str
    else:
        LG = None
    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)
@@ -467,6 +506,7 @@
        params=params,
        model=model,
        sp=sp,
        LG=LG,
    )
    save_results(

@@ -0,0 +1,78 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import k2
import torch
def shallow_fusion(
LG: k2.Fsa,
token: int,
state_and_scores: Dict[int, torch.Tensor],
) -> Dict[int, torch.Tensor]:
"""
Args:
LG:
        An n-gram LM represented as an FSA, e.g., the LG produced by
        local/compile_lg.py. It should be arc sorted and epsilon free.
token:
The input token ID.
state_and_scores:
The keys contain the current state we are in and the
values are the LM log_prob for reaching the corresponding
states from the start state.
Returns:
Return a new state_and_scores.
"""
row_splits = LG.arcs.row_splits(1)
arcs = LG.arcs.values()
current_states = list(state_and_scores.keys())
ans = dict()
for s in current_states:
labels_begin = row_splits[s]
labels_end = row_splits[s + 1]
labels = LG.labels[labels_begin:labels_end].contiguous()
        # As LG is not deterministic, there may be multiple
        # out-going arcs with label equal to "token".
#
# Note: LG is arc sorted!
left = torch.bucketize(token, labels, right=False)
right = torch.bucketize(token, labels, right=True)
if left >= right:
# There are no out-going arcs from this state
# that have label equal to "token"
continue
# Now we have
# labels[i] == token
# for
# left <= i < right
        for i in range(left, right):
            arc_idx = labels_begin + i
            next_state = arcs[arc_idx][1].item()
            score = LG.scores[arc_idx]
if next_state not in ans:
ans[next_state] = score
else:
ans[next_state] = max(score, ans[next_state])
return ans
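
As a sanity check, here is a minimal usage sketch (the token ID 22 is an
arbitrary placeholder, and it assumes data/lang_bpe_500/LG.pt was generated
by local/compile_lg.py):

    import k2
    import torch
    from shallow_fusion import shallow_fusion

    LG = k2.Fsa.from_dict(
        torch.load("data/lang_bpe_500/LG.pt", map_location="cpu")
    )
    # Decoding starts in LG's start state (state 0) with a zero LM score.
    states = {0: torch.zeros(1, dtype=torch.float32)}
    # Advance every active state by one token.
    states = shallow_fusion(LG, 22, states)
    for state, score in states.items():
        print(state, score.item())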