Add LG decoding (#277)

* Add LG decoding

* Add log weight pushing

* Minor fixes
Wei Kang 2022-04-19 17:23:46 +08:00 committed by GitHub
parent 5fe58de43c
commit 021c79824e
4 changed files with 344 additions and 19 deletions

local/compile_lg.py (new file)

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates LG from

    - L, the lexicon, built from lang_dir/L_disambig.pt

      Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_3_gram.fst.txt

The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_LG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

    Return:
      An FSA representing LG.
    """
    lexicon = Lexicon(lang_dir)

    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G_3_gram.pt").is_file():
        logging.info("Loading pre-compiled G_3_gram")
        d = torch.load("data/lm/G_3_gram.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        with open("data/lm/G_3_gram.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")
    # Determinize with weight pushing in the log semiring (the "log weight
    # pushing" referred to in the commit message).
    LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")
    LG.labels[LG.labels >= first_token_disambig_id] = 0
    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set LG.properties to None
    LG.__dict__["_properties"] = None

    assert isinstance(LG.aux_labels, k2.RaggedTensor)
    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    return LG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "LG.pt").is_file():
        logging.info(f"{lang_dir}/LG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir)
    logging.info(f"Saving LG.pt to {lang_dir}")
    torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
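A quick sanity check of the generated graph (not part of this commit; the path below assumes the default --lang-dir used by decode.py further down):

import k2
import torch

# Load the LG graph written by compile_lg.py and print a couple of basic stats.
LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt"))
print("shape:", LG.shape)      # (num_states, None) for a single FSA
print("num_arcs:", LG.num_arcs)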


@@ -242,3 +242,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
    ./local/compile_hlg.py --lang-dir $lang_dir
  done
fi

# Compile LG for RNN-T fast_beam_search decoding
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
  log "Stage 10: Compile LG"
  ./local/compile_lg.py --lang-dir data/lang_phone
  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bpe_${vocab_size}
    ./local/compile_lg.py --lang-dir $lang_dir
  done
fi


@@ -22,7 +22,7 @@ import k2
import torch
from model import Transducer

from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts

@@ -34,6 +34,7 @@ def fast_beam_search(
    beam: float,
    max_states: int,
    max_contexts: int,
    use_max: bool = False,
) -> List[List[int]]:
    """It limits the maximum number of symbols per frame to 1.

@@ -53,6 +54,9 @@ def fast_beam_search(
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      use_max:
        True to use max operation to select the hypothesis with the largest
        log_prob when there are duplicate hypotheses; False to use log-add.
    Returns:
      Return the decoded result.
    """

@@ -104,9 +108,67 @@ def fast_beam_search(
    decoding_streams.terminate_and_flush_to_streams()

    lattice = decoding_streams.format_output(encoder_out_lens.tolist())

    if use_max:
        best_path = one_best_decoding(lattice)
        hyps = get_texts(best_path)
        return hyps
    else:
        num_paths = 200
        use_double_scores = True
        nbest_scale = 0.8

        nbest = Nbest.from_lattice(
            lattice=lattice,
            num_paths=num_paths,
            use_double_scores=use_double_scores,
            nbest_scale=nbest_scale,
        )

        # The following code is modified from nbest.intersect()
        word_fsa = k2.invert(nbest.fsa)
        if hasattr(lattice, "aux_labels"):
            # delete token IDs as they are not needed
            del word_fsa.aux_labels
        word_fsa.scores.zero_()
        word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
        path_to_utt_map = nbest.shape.row_ids(1)

        if hasattr(lattice, "aux_labels"):
            # lattice has token IDs as labels and word IDs as aux_labels.
            # inv_lattice has word IDs as labels and token IDs as aux_labels
            inv_lattice = k2.invert(lattice)
            inv_lattice = k2.arc_sort(inv_lattice)
        else:
            inv_lattice = k2.arc_sort(lattice)

        if inv_lattice.shape[0] == 1:
            path_lattice = k2.intersect_device(
                inv_lattice,
                word_fsa_with_epsilon_loops,
                b_to_a_map=torch.zeros_like(path_to_utt_map),
                sorted_match_a=True,
            )
        else:
            path_lattice = k2.intersect_device(
                inv_lattice,
                word_fsa_with_epsilon_loops,
                b_to_a_map=path_to_utt_map,
                sorted_match_a=True,
            )

        # path_lattice has word IDs as labels and token IDs as aux_labels
        path_lattice = k2.top_sort(k2.connect(path_lattice))

        tot_scores = path_lattice.get_tot_scores(
            use_double_scores=use_double_scores, log_semiring=True
        )

        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)

        best_hyp_indexes = ragged_tot_scores.argmax()
        best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)

        hyps = get_texts(best_path)

        return hyps
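# Toy illustration (not part of this commit) of the final selection step in the
# use_max=False branch above: the total scores of the n-best paths are grouped
# per utterance in a ragged tensor, and argmax() returns, for each utterance,
# the index of its best path into the flattened score array.
import k2

# Utterance 0 has 3 candidate paths, utterance 1 has 2.
ragged_tot_scores = k2.RaggedTensor("[[0.5 1.2 0.3] [-0.1 0.7]]")
print(ragged_tot_scores.argmax())  # -> [1, 4]: best path index per utterance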


def greedy_search(

@@ -280,7 +342,7 @@ class HypothesisList(object):
    def data(self) -> Dict[str, Hypothesis]:
        return self._data

    def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
        """Add a Hypothesis to `self`.

        If `hyp` already exists in `self`, its probability is updated using

@@ -289,10 +351,17 @@ class HypothesisList(object):
        Args:
          hyp:
            The hypothesis to be added.
          use_max:
            True to select the hypothesis with the larger log_prob in case
            there already exists a hypothesis whose `ys` equals `hyp.ys`;
            False to use log-add.
        """
        key = hyp.key
        if key in self:
            old_hyp = self._data[key]  # shallow copy
            if use_max:
                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
            else:
                torch.logaddexp(
                    old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
                )
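# Toy illustration (not part of this commit) of the two merge rules above for
# duplicate hypotheses: use_max=True keeps the better log_prob, use_max=False
# adds the probabilities in log space.
import torch

a = torch.tensor(-1.0)  # log_prob of the existing hypothesis
b = torch.tensor(-1.5)  # log_prob of the duplicate being added
print(torch.max(a, b))        # tensor(-1.)
print(torch.logaddexp(a, b))  # ~ tensor(-0.5259) == log(exp(-1.0) + exp(-1.5))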

@@ -403,6 +472,7 @@ def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
    use_max: bool = False,
) -> List[List[int]]:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

@@ -413,6 +483,9 @@ def modified_beam_search(
        Output from the encoder. Its shape is (N, T, C).
      beam:
        Number of active paths during the beam search.
      use_max:
        True to use max operation to select the hypothesis with the largest
        log_prob when there are duplicate hypotheses; False to use log-add.
    Returns:
      Return a list-of-list of token IDs. ans[i] is the decoding results
      for the i-th utterance.

@@ -432,7 +505,8 @@ def modified_beam_search(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
            ),
            use_max=use_max,
        )

    for t in range(T):

@@ -517,6 +591,7 @@ def _deprecated_modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
    use_max: bool = False,
) -> List[int]:
    """It limits the maximum number of symbols per frame to 1.

@@ -532,6 +607,9 @@ def _deprecated_modified_beam_search(
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
      use_max:
        True to use max operation to select the hypothesis with the largest
        log_prob when there are duplicate hypotheses; False to use log-add.
    Returns:
      Return the decoded result.
    """

@@ -553,7 +631,8 @@ def _deprecated_modified_beam_search(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
        ),
        use_max=use_max,
    )

    for t in range(T):

@@ -611,7 +690,7 @@ def _deprecated_modified_beam_search(
                new_ys.append(new_token)
            new_log_prob = topk_log_probs[i]
            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
            B.add(new_hyp, use_max=use_max)

    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks

@@ -623,6 +702,7 @@ def beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
    use_max: bool = False,
) -> List[int]:
    """
    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf

@@ -636,6 +716,9 @@ def beam_search(
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
      use_max:
        True to use max operation to select the hypothesis with the largest
        log_prob when there are duplicate hypotheses; False to use log-add.
    Returns:
      Return the decoded result.
    """

@@ -661,7 +744,9 @@ def beam_search(
    t = 0

    B = HypothesisList()
    B.add(
        Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max
    )

    max_sym_per_utt = 20000

@@ -720,7 +805,10 @@ def beam_search(
            new_y_star_log_prob = y_star.log_prob + skip_log_prob

            # ys[:] returns a copy of ys
            B.add(
                Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob),
                use_max=use_max,
            )

            # Second, process other non-blank labels
            values, indices = log_prob.topk(beam + 1)

@@ -729,7 +817,10 @@ def beam_search(
                    continue
                new_ys = y_star.ys + [i]
                new_log_prob = y_star.log_prob + v
                A.add(
                    Hypothesis(ys=new_ys, log_prob=new_log_prob),
                    use_max=use_max,
                )

            # Check whether B contains more than "beam" elements more probable
            # than the most probable in A

pruned_transducer_stateless/decode.py

@@ -53,6 +53,19 @@ Usage:
        --beam 4 \
        --max-contexts 4 \
        --max-states 8

(5) fast beam search using LG
./pruned_transducer_stateless/decode.py \
        --epoch 28 \
        --avg 15 \
        --exp-dir ./pruned_transducer_stateless/exp \
        --use-LG True \
        --use-max False \
        --max-duration 1500 \
        --decoding-method fast_beam_search \
        --beam 8 \
        --max-contexts 8 \
        --max-states 64
"""

@@ -81,10 +94,12 @@ from icefall.checkpoint import (
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)

@@ -136,6 +151,13 @@ def get_parser():
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_500",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,

@@ -167,6 +189,36 @@ def get_parser():
        Used only when --decoding-method is fast_beam_search""",
    )

    parser.add_argument(
        "--use-LG",
        type=str2bool,
        default=False,
        help="""Whether to use an LG graph for FSA-based beam search.
        Used only when --decoding-method is fast_beam_search. If set to True,
        it assumes there is an LG.pt file in lang_dir.""",
    )

    parser.add_argument(
        "--use-max",
        type=str2bool,
        default=False,
        help="""If True, use the max operation to select the hypothesis with
        the largest log_prob when there are duplicate hypotheses;
        if False, use log-add.
        Used only for beam_search, modified_beam_search, and fast_beam_search.
        """,
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.01,
        help="""
        Used only when --decoding-method is fast_beam_search.
        It specifies the scale for n-gram LM scores.
        """,
    )

    parser.add_argument(
        "--max-contexts",
        type=int,

@@ -206,6 +258,7 @@ def decode_one_batch(
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the

@@ -229,6 +282,8 @@ def decode_one_batch(
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding-method is fast_beam_search.

@@ -260,7 +315,12 @@ def decode_one_batch(
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            use_max=params.use_max,
        )
        if params.use_LG:
            for hyp in hyp_tokens:
                hyps.append([word_table[i] for i in hyp])
        else:
            for hyp in sp.decode(hyp_tokens):
                hyps.append(hyp.split())
    elif (

@@ -278,6 +338,7 @@ def decode_one_batch(
            model=model,
            encoder_out=encoder_out,
            beam=params.beam_size,
            use_max=params.use_max,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())

@@ -299,6 +360,7 @@ def decode_one_batch(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                    use_max=params.use_max,
                )
            else:
                raise ValueError(

@@ -325,6 +387,7 @@ def decode_dataset(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

@@ -338,6 +401,8 @@ def decode_dataset(
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding-method is fast_beam_search.

@@ -368,8 +433,9 @@ def decode_dataset(
            params=params,
            model=model,
            sp=sp,
            batch=batch,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )

        for name, hyps in hyps_dict.items():

@@ -460,13 +526,16 @@ def main():
    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-use-LG-{params.use_LG}"
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        params.suffix += f"-use-max-{params.use_max}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
        params.suffix += f"-use-max-{params.use_max}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -527,9 +596,21 @@ def main():
    model.device = device

    if params.decoding_method == "fast_beam_search":
        if params.use_LG:
            lexicon = Lexicon(params.lang_dir)
            word_table = lexicon.word_table
            decoding_graph = k2.Fsa.from_dict(
                torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
            )
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(
                params.vocab_size - 1, device=device
            )
    else:
        decoding_graph = None
        word_table = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

@@ -551,6 +632,7 @@ def main():
            params=params,
            model=model,
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )