WIP: Use shallow fusion in modified beam search.

Fangjun Kuang 2022-02-08 20:40:45 +08:00
parent 27fa5f05d3
commit 954b4efff3
5 changed files with 417 additions and 4 deletions

local/compile_lg.py

@@ -0,0 +1,160 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates LG from
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
The generated LG is saved in $lang_dir/LG.fst
"""
import argparse
import logging
from pathlib import Path

import k2
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()
def compile_LG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_500.

    Return:
      An FST representing LG.
    """
    tokens = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
    assert "#0" in tokens

    first_token_disambig_id = tokens["#0"]
    logging.info(f"first token disambig ID: {first_token_disambig_id}")

    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G_3_gram.pt").is_file():
        logging.info("Loading pre-compiled G_3_gram")
        d = torch.load("data/lm/G_3_gram.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        with open("data/lm/G_3_gram.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            del G.aux_labels
            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Composing L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}, num_arcs: {LG.num_arcs}")

    del LG.aux_labels

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Determinizing LG")
    LG = k2.determinize(LG)
    logging.info(
        f"LG shape after k2.determinize: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)
    logging.info(
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Removing disambiguation symbols on LG")
    LG.labels[LG.labels >= first_token_disambig_id] = 0
    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set LG.properties to None
    LG.__dict__["_properties"] = None

    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(
        f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Connecting")
    LG = k2.connect(LG)
    logging.info(
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info(f"LG properties: {LG.properties_str}")
    # The expected properties are:
    # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible"
    #
    # Caution: LG is not deterministic, since k2.remove_epsilon may
    # reintroduce nondeterminism after the disambiguation symbols are
    # replaced with epsilon.
    logging.info("Caution: LG is not deterministic!!!")

    return LG
def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    out_filename = lang_dir / "LG.pt"

    if out_filename.is_file():
        logging.info(f"{out_filename} already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir)
    logging.info(f"Saving LG to {out_filename}")
    torch.save(LG.as_dict(), out_filename)


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
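
For a quick sanity check after running the script, the saved LG can be loaded back and inspected. A minimal sketch (the lang dir below is an example path):

import k2
import torch

LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt", map_location="cpu"))
# compile_lg.py arc sorts LG and removes epsilons before saving,
# so both properties should show up here.
print(LG.properties_str)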

local/test_compile_lg.py

@@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./local/test_compile_lg.py
"""
from pathlib import Path
from typing import List
import k2
import sentencepiece as spm
import torch
lang_dir = Path("./data/lang_bpe_500")
def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]:
"""
Args:
word_table:
Word symbol table.
s:
A string consisting of space(s) separated words.
Returns:
Return a list of word IDs.
"""
ans = []
for w in s.split():
ans.append(word_table[w])
return ans
def main():
assert lang_dir.exists(), f"{lang_dir} does not exist!"
LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu"))
sp = spm.SentencePieceProcessor()
sp.load(f"{lang_dir}/bpe.model")
word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
s = "HELLO WORLD"
token_ids = sp.encode(s)
token_fsa = k2.linear_fsa(token_ids)
fsa = k2.intersect(LG, token_fsa)
fsa = k2.connect(fsa)
print(k2.to_dot(fsa))
print(fsa.properties_str)
print(LG.properties_str)
# You can use https://dreampuf.github.io/GraphvizOnline/
# to visualize the output.
#
# You can see that the resulting fsa is not deterministic
# Note: LG is non-deterministic
#
# See https://shorturl.at/uIL69
# for visualization of the above fsa.
if __name__ == "__main__":
main()
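
For reference, the same intersect-and-connect pattern can be tried without any LG on disk. A self-contained toy sketch (the small acceptor below is made up for illustration; it stands in for LG):

import k2

# A linear FSA accepting the token sequence [1, 2, 3].
token_fsa = k2.linear_fsa([1, 2, 3])

# A tiny acceptor. Each line is "src dst label score"; the last
# line gives the final state; label -1 marks arcs entering it.
s = """0 1 1 0.1
1 2 2 0.2
2 3 3 0.3
3 4 -1 0.0
4"""
toy = k2.arc_sort(k2.Fsa.from_str(s))

result = k2.connect(k2.intersect(toy, token_fsa))
print(k2.to_dot(result))
print(result.properties_str)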

beam_search.py

@@ -17,8 +17,10 @@
from dataclasses import dataclass
from typing import Dict, List, Optional

import k2
import torch
from model import Transducer

from shallow_fusion import shallow_fusion


def greedy_search(
@@ -111,6 +113,13 @@ class Hypothesis:
    # It contains only one entry.
    log_prob: torch.Tensor

    # Used for shallow fusion.
    # The key of the dict is a state index into LG
    # while the corresponding value is the LM score
    # for reaching this state.
    # Note: The value tensor contains only a single entry.
    ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
@@ -149,6 +158,15 @@ class HypothesisList(object):
            torch.logaddexp(
                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
            )
            if hyp.ngram_state_and_scores is not None:
                for state, score in hyp.ngram_state_and_scores.items():
                    # Keep the better LM score for each LG state.
                    if (
                        state not in old_hyp.ngram_state_and_scores
                        or score > old_hyp.ngram_state_and_scores[state]
                    ):
                        old_hyp.ngram_state_and_scores[state] = score
        else:
            self._data[key] = hyp
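
Viewed in isolation, the merge above is a max-merge of two state-to-score maps. A minimal pure-Python sketch (merge_ngram_states is a hypothetical helper, not part of this commit):

from typing import Dict

def merge_ngram_states(
    old: Dict[int, float], new: Dict[int, float]
) -> Dict[int, float]:
    # For every LG state seen by either hypothesis,
    # keep the larger (better) LM score.
    for state, score in new.items():
        if state not in old or score > old[state]:
            old[state] = score
    return old

print(merge_ngram_states({0: -1.5, 3: -0.2}, {3: -0.1, 7: -2.0}))
# {0: -1.5, 3: -0.1, 7: -2.0}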
@@ -318,6 +336,7 @@ def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
    LG: Optional[k2.Fsa] = None,
) -> List[int]:
    """It limits the maximum number of symbols per frame to 1.
@@ -328,9 +347,13 @@
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
      LG:
        Optional. Used for shallow fusion.
    Returns:
      Return the decoded result.
    """
    enable_shallow_fusion = LG is not None
    ngram_lm_scale = 0.8

    assert encoder_out.ndim == 3
@@ -350,10 +373,19 @@
    T = encoder_out.size(1)

    B = HypothesisList()

    if enable_shallow_fusion:
        # Start from the start state of LG (state 0 in k2) with
        # an LM score of 0.
        ngram_state_and_scores = {
            0: torch.zeros(1, dtype=torch.float32, device=device)
        }
    else:
        ngram_state_and_scores = None

    B.add(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
            ngram_state_and_scores=ngram_state_and_scores,
        )
    )
@@ -411,9 +443,33 @@
            new_token = topk_token_indexes[i]
            if new_token != blank_id:
                new_ys.append(new_token)

            new_log_prob = topk_log_probs[i]

            # Carry over the parent's LG states and scores; they are
            # advanced below only when a non-blank token is emitted.
            ngram_state_and_scores = hyp.ngram_state_and_scores
            if enable_shallow_fusion and new_token != blank_id:
                ngram_state_and_scores = shallow_fusion(
                    LG, new_token, hyp.ngram_state_and_scores
                )
                if len(ngram_state_and_scores) == 0:
                    # No state in LG accepts this token; prune
                    # the hypothesis.
                    continue
                # Add the best LM score to the hypothesis score.
                max_ngram_score = max(ngram_state_and_scores.values())
                new_log_prob += ngram_lm_scale * max_ngram_score

            new_hyp = Hypothesis(
                ys=new_ys,
                log_prob=new_log_prob,
                ngram_state_and_scores=ngram_state_and_scores,
            )
            B.add(new_hyp)
        if len(B) == 0:
            # Shallow fusion pruned every hypothesis for this frame;
            # fall back to the unpruned set A.
            for h in A:
                B.add(h)

    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove the blanks
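
Numerically, each non-blank extension combines the transducer score with the scaled LM score. A toy calculation using the hard-coded ngram_lm_scale of 0.8 (the numbers are made up):

log_prob = -3.2          # cumulative transducer score of the extended hypothesis
max_ngram_score = -1.5   # best LM score over the reachable LG states
ngram_lm_scale = 0.8

fused = log_prob + ngram_lm_scale * max_ngram_score
print(fused)  # -4.4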

decode.py

@@ -40,8 +40,9 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.nn as nn
@@ -131,6 +132,13 @@
            Used only when --decoding-method is greedy_search""",
    )

    parser.add_argument(
        "--LG",
        type=str,
        help="""Path to LG.pt for shallow fusion.
        Used only when --decoding-method is modified_beam_search.""",
    )

    return parser
@@ -203,6 +211,7 @@ def decode_one_batch(
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    LG: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -225,6 +234,9 @@
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      LG:
        Optional. Used for shallow fusion. Used only when
        params.decoding_method is modified_beam_search.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
@@ -257,17 +269,24 @@
            )
        elif params.decoding_method == "beam_search":
            hyp = beam_search(
                model=model,
                encoder_out=encoder_out_i,
                beam=params.beam_size,
            )
        elif params.decoding_method == "modified_beam_search":
            hyp = modified_beam_search(
                model=model,
                encoder_out=encoder_out_i,
                beam=params.beam_size,
                LG=LG,
            )
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
            )
        hyps.append(sp.decode(hyp).split())

    # Temporary debug output for the WIP shallow fusion.
    for h in hyps:
        print(" ".join(h))

    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
@@ -280,6 +299,7 @@ def decode_dataset(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    LG: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.
@@ -292,6 +312,9 @@
        The neural model.
      sp:
        The BPE model.
      LG:
        Optional. Used for shallow fusion. Used only when
        params.decoding_method is modified_beam_search.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
@@ -320,6 +343,7 @@
            model=model,
            sp=sp,
            batch=batch,
            LG=LG,
        )

        for name, hyps in hyps_dict.items():
@@ -419,6 +443,21 @@ def main():
    logging.info(f"Device: {device}")

    if params.LG is not None:
        assert (
            params.decoding_method == "modified_beam_search"
        ), "--LG is used only when --decoding-method is modified_beam_search"
        logging.info(f"Loading LG from {params.LG}")
        LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device))
        logging.info(f"LG properties: {LG.properties_str}")
        logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")

        # If LG was created by local/compile_lg.py, it should be epsilon
        # free as well as arc sorted.
        assert "ArcSorted" in LG.properties_str
        assert "EpsilonFree" in LG.properties_str
    else:
        LG = None

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)
@@ -467,6 +506,7 @@
        params=params,
        model=model,
        sp=sp,
        LG=LG,
    )

    save_results(

shallow_fusion.py

@@ -0,0 +1,78 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import k2
import torch


def shallow_fusion(
    LG: k2.Fsa,
    token: int,
    state_and_scores: Dict[int, torch.Tensor],
) -> Dict[int, torch.Tensor]:
    """
    Args:
      LG:
        An n-gram LM represented as an FST. It should be arc sorted
        and epsilon free.
      token:
        The input token ID.
      state_and_scores:
        The keys are the states of LG we are currently in; the values
        are the LM scores associated with reaching those states. Each
        value tensor contains a single entry.
    Returns:
      Return a new state_and_scores that maps each state reachable
      from the given states via an arc labeled ``token`` to the best
      score among the matching arcs.
    """
    row_splits = LG.arcs.row_splits(1)
    arcs = LG.arcs.values()

    current_states = list(state_and_scores.keys())

    ans = dict()
    for s in current_states:
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()

        # As LG is not deterministic, there may be multiple
        # out-going arcs whose label equals "token".
        #
        # Since LG is arc sorted, we can use binary search
        # (torch.bucketize) to locate the matching range of arcs.
        left = torch.bucketize(token, labels, right=False)
        right = torch.bucketize(token, labels, right=True)

        if left >= right:
            # There are no out-going arcs from this state
            # whose label equals "token".
            continue

        # Now we have
        #   labels[i] == token
        # for
        #   left <= i < right
        for i in range(left, right):
            idx = i + labels_begin
            next_state = arcs[idx][1].item()
            score = LG.scores[idx]
            if next_state not in ans:
                ans[next_state] = score
            else:
                ans[next_state] = max(score, ans[next_state])
    return ans
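
As a quick check of the traversal logic, shallow_fusion can be exercised on a toy acceptor. A sketch (the 4-state FSA below is made up; it stands in for a real LG):

import k2
import torch

# Two parallel arcs labeled 10 leave state 0, mimicking a
# non-deterministic LG. Format: "src dst label score";
# label -1 marks arcs entering the final state.
s = """0 1 10 -0.5
0 2 10 -1.5
1 3 -1 0.0
2 3 -1 0.0
3"""
toy_LG = k2.arc_sort(k2.Fsa.from_str(s))

states = {0: torch.zeros(1)}
print(shallow_fusion(toy_LG, 10, states))
# Both successors of state 0 are reached, each with its arc score,
# e.g. {1: tensor(-0.5), 2: tensor(-1.5)}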