Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

Add LG decoding (#277)

* Add LG decoding
* Add log weight pushing
* Minor fixes

Commit 021c79824e (parent 5fe58de43c)

New executable file (141 lines): egs/librispeech/ASR/local/compile_lg.py

@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates LG from

    - L, the lexicon, built from lang_dir/L_disambig.pt

      Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_3_gram.fst.txt

The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_LG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

    Return:
      An FSA representing LG.
    """
    lexicon = Lexicon(lang_dir)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G_3_gram.pt").is_file():
        logging.info("Loading pre-compiled G_3_gram")
        d = torch.load("data/lm/G_3_gram.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        with open("data/lm/G_3_gram.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    LG.labels[LG.labels >= first_token_disambig_id] = 0
    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set LG.properties to None
    LG.__dict__["_properties"] = None

    assert isinstance(LG.aux_labels, k2.RaggedTensor)
    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    return LG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "LG.pt").is_file():
        logging.info(f"{lang_dir}/LG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir)
    logging.info(f"Saving LG.pt to {lang_dir}")
    torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
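The script writes LG.pt via `LG.as_dict()`, so the graph round-trips through `k2.Fsa.from_dict`. A minimal sanity-check sketch, assuming data/lang_bpe_500 has already been processed (the lang dir name is illustrative, not prescribed by the commit):

    import k2
    import torch

    # LG.pt stores the FSA as a dict; from_dict() restores it.
    LG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/LG.pt"))
    print(LG.shape)             # (num_states, None) for a single FSA
    print(type(LG.aux_labels))  # k2.RaggedTensor, matching the assert above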
@@ -242,3 +242,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
     ./local/compile_hlg.py --lang-dir $lang_dir
   done
 fi
+
+# Compile LG for RNN-T fast_beam_search decoding
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Compile LG"
+  ./local/compile_lg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_lg.py --lang-dir $lang_dir
+  done
+fi
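Stage 10 is idempotent because main() in compile_lg.py skips any lang dir that already contains LG.pt. A hedged programmatic equivalent for a single lang dir (it assumes local/ is on sys.path so the file shown above is importable; the lang dir name is illustrative):

    import torch
    from pathlib import Path

    from compile_lg import compile_LG  # the file shown above

    lang_dir = Path("data/lang_bpe_500")
    if not (lang_dir / "LG.pt").is_file():  # same skip logic as main()
        LG = compile_LG(lang_dir)
        torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")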
@@ -22,7 +22,7 @@ import k2
 import torch
 from model import Transducer

-from icefall.decode import one_best_decoding
+from icefall.decode import Nbest, one_best_decoding
 from icefall.utils import get_texts
@@ -34,6 +34,7 @@ def fast_beam_search(
     beam: float,
     max_states: int,
     max_contexts: int,
+    use_max: bool = False,
 ) -> List[List[int]]:
     """It limits the maximum number of symbols per frame to 1.
@@ -53,6 +54,9 @@ def fast_beam_search(
         Max states per stream per frame.
       max_contexts:
         Max contexts per stream per frame.
+      use_max:
+        True to use max operation to select the hypothesis with the largest
+        log_prob when there are duplicate hypotheses; False to use log-add.
     Returns:
       Return the decoded result.
     """
@@ -104,9 +108,67 @@ def fast_beam_search(
     decoding_streams.terminate_and_flush_to_streams()
     lattice = decoding_streams.format_output(encoder_out_lens.tolist())

-    best_path = one_best_decoding(lattice)
-    hyps = get_texts(best_path)
-    return hyps
+    if use_max:
+        best_path = one_best_decoding(lattice)
+        hyps = get_texts(best_path)
+        return hyps
+    else:
+        num_paths = 200
+        use_double_scores = True
+        nbest_scale = 0.8
+
+        nbest = Nbest.from_lattice(
+            lattice=lattice,
+            num_paths=num_paths,
+            use_double_scores=use_double_scores,
+            nbest_scale=nbest_scale,
+        )
+        # The following code is modified from nbest.intersect()
+        word_fsa = k2.invert(nbest.fsa)
+        if hasattr(lattice, "aux_labels"):
+            # delete token IDs as they are not needed
+            del word_fsa.aux_labels
+        word_fsa.scores.zero_()
+
+        word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+        path_to_utt_map = nbest.shape.row_ids(1)
+
+        if hasattr(lattice, "aux_labels"):
+            # lattice has token IDs as labels and word IDs as aux_labels.
+            # inv_lattice has word IDs as labels and token IDs as aux_labels
+            inv_lattice = k2.invert(lattice)
+            inv_lattice = k2.arc_sort(inv_lattice)
+        else:
+            inv_lattice = k2.arc_sort(lattice)
+
+        if inv_lattice.shape[0] == 1:
+            path_lattice = k2.intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=torch.zeros_like(path_to_utt_map),
+                sorted_match_a=True,
+            )
+        else:
+            path_lattice = k2.intersect_device(
+                inv_lattice,
+                word_fsa_with_epsilon_loops,
+                b_to_a_map=path_to_utt_map,
+                sorted_match_a=True,
+            )
+
+        # path_lattice has word IDs as labels and token IDs as aux_labels
+        path_lattice = k2.top_sort(k2.connect(path_lattice))
+
+        tot_scores = path_lattice.get_tot_scores(
+            use_double_scores=use_double_scores, log_semiring=True
+        )
+
+        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        best_hyp_indexes = ragged_tot_scores.argmax()
+
+        best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes)
+        hyps = get_texts(best_path)
+        return hyps


 def greedy_search(
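In the `use_max=False` branch above, `nbest.shape` groups the sampled paths by utterance, so `ragged_tot_scores.argmax()` yields one index per utterance, pointing into the flattened list of paths. A small standalone sketch of that per-sublist argmax (the scores are made up):

    import k2

    # Two utterances: three sampled paths for the first, two for the second.
    scores = k2.RaggedTensor("[ [0.1 0.9 0.3] [0.7 0.2] ]")
    # argmax() is taken per sub-list; the indices refer to the flattened
    # values, so this should print [1, 3].
    print(scores.argmax())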
@@ -280,7 +342,7 @@ class HypothesisList(object):
     def data(self) -> Dict[str, Hypothesis]:
         return self._data

-    def add(self, hyp: Hypothesis) -> None:
+    def add(self, hyp: Hypothesis, use_max: bool = False) -> None:
         """Add a Hypothesis to `self`.

         If `hyp` already exists in `self`, its probability is updated using
@@ -289,13 +351,20 @@ class HypothesisList(object):
         Args:
           hyp:
             The hypothesis to be added.
+          use_max:
+            True to keep the hypothesis with the larger log_prob when there
+            already exists a hypothesis whose `ys` equals `hyp.ys`;
+            False to use log-add.
         """
         key = hyp.key
         if key in self:
             old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
+            if use_max:
+                old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)
+            else:
+                torch.logaddexp(
+                    old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+                )
         else:
             self._data[key] = hyp
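The two merge rules differ as follows: log-add treats a duplicate hypothesis as extra probability mass for the same symbol sequence and sums the probabilities, while max keeps only the better-scoring path. A tiny worked example with made-up log-probs:

    import torch

    a = torch.tensor(-1.2)  # log_prob of the existing hypothesis
    b = torch.tensor(-1.5)  # log_prob of the duplicate being added

    # log-add: log(exp(-1.2) + exp(-1.5)) ≈ -0.65, the probabilities sum
    print(torch.logaddexp(a, b))
    # max: keeps -1.2, the better path alone
    print(torch.maximum(a, b))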
@@ -403,6 +472,7 @@ def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+    use_max: bool = False,
 ) -> List[List[int]]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

@@ -413,6 +483,9 @@ def modified_beam_search(
         Output from the encoder. Its shape is (N, T, C).
       beam:
         Number of active paths during the beam search.
+      use_max:
+        True to use max operation to select the hypothesis with the largest
+        log_prob when there are duplicate hypotheses; False to use log-add.
     Returns:
       Return a list-of-list of token IDs. ans[i] is the decoding result
       for the i-th utterance.
@@ -432,7 +505,8 @@ def modified_beam_search(
         Hypothesis(
             ys=[blank_id] * context_size,
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-        )
+        ),
+        use_max=use_max,
     )

     for t in range(T):
@@ -517,6 +591,7 @@ def _deprecated_modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+    use_max: bool = False,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.

@@ -532,6 +607,9 @@ def _deprecated_modified_beam_search(
         A tensor of shape (N, T, C) from the encoder. Supports only N==1 for now.
       beam:
         Beam size.
+      use_max:
+        True to use max operation to select the hypothesis with the largest
+        log_prob when there are duplicate hypotheses; False to use log-add.
     Returns:
       Return the decoded result.
     """
@@ -553,12 +631,13 @@ def _deprecated_modified_beam_search(
         Hypothesis(
             ys=[blank_id] * context_size,
             log_prob=torch.zeros(1, dtype=torch.float32, device=device),
-        )
+        ),
+        use_max=use_max,
     )

     for t in range(T):
         # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
+        current_encoder_out = encoder_out[:, t:t + 1, :].unsqueeze(2)
         # current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
         # fmt: on
         A = list(B)
@@ -611,7 +690,7 @@ def _deprecated_modified_beam_search(
             new_ys.append(new_token)
             new_log_prob = topk_log_probs[i]
             new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
-            B.add(new_hyp)
+            B.add(new_hyp, use_max=use_max)

     best_hyp = B.get_most_probable(length_norm=True)
     ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
@@ -623,6 +702,7 @@ def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+    use_max: bool = False,
 ) -> List[int]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf

@@ -636,6 +716,9 @@ def beam_search(
         A tensor of shape (N, T, C) from the encoder. Supports only N==1 for now.
       beam:
         Beam size.
+      use_max:
+        True to use max operation to select the hypothesis with the largest
+        log_prob when there are duplicate hypotheses; False to use log-add.
     Returns:
       Return the decoded result.
     """
@@ -661,7 +744,9 @@ def beam_search(
     t = 0

     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(
+        Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max
+    )

     max_sym_per_utt = 20000

@@ -720,7 +805,10 @@ def beam_search(
             new_y_star_log_prob = y_star.log_prob + skip_log_prob

             # ys[:] returns a copy of ys
-            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
+            B.add(
+                Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob),
+                use_max=use_max,
+            )

             # Second, process other non-blank labels
             values, indices = log_prob.topk(beam + 1)
@@ -729,7 +817,10 @@ def beam_search(
                     continue
                 new_ys = y_star.ys + [i]
                 new_log_prob = y_star.log_prob + v
-                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+                A.add(
+                    Hypothesis(ys=new_ys, log_prob=new_log_prob),
+                    use_max=use_max,
+                )

             # Check whether B contains more than "beam" elements more probable
             # than the most probable in A
@@ -53,6 +53,19 @@ Usage:
     --beam 4 \
     --max-contexts 4 \
     --max-states 8
+
+(5) fast beam search using LG
+./pruned_transducer_stateless/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless/exp \
+    --use-LG True \
+    --use-max False \
+    --max-duration 1500 \
+    --decoding-method fast_beam_search \
+    --beam 8 \
+    --max-contexts 8 \
+    --max-states 64
 """
@@ -81,10 +94,12 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
@@ -136,6 +151,13 @@ def get_parser():
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir containing the word table and the LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,
@@ -167,6 +189,36 @@ def get_parser():
         Used only when --decoding-method is fast_beam_search""",
     )

+    parser.add_argument(
+        "--use-LG",
+        type=str2bool,
+        default=False,
+        help="""Whether to use an LG graph for FSA-based beam search.
+        Used only when --decoding-method is fast_beam_search. If set to True,
+        it assumes there is an LG.pt file in lang_dir.""",
+    )
+
+    parser.add_argument(
+        "--use-max",
+        type=str2bool,
+        default=False,
+        help="""If True, use the max operation to select the hypothesis with
+        the largest log_prob in case of duplicate hypotheses;
+        if False, use log-add.
+        Used only for beam_search, modified_beam_search, and fast_beam_search.
+        """,
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.01,
+        help="""
+        Used only when --decoding-method is fast_beam_search.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
@@ -206,6 +258,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -229,6 +282,8 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      word_table:
+        The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
         Used only when --decoding-method is fast_beam_search.
@@ -260,9 +315,14 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            use_max=params.use_max,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        if params.use_LG:
+            for hyp in hyp_tokens:
+                hyps.append([word_table[i] for i in hyp])
+        else:
+            for hyp in sp.decode(hyp_tokens):
+                hyps.append(hyp.split())
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
@@ -278,6 +338,7 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             beam=params.beam_size,
+            use_max=params.use_max,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -299,6 +360,7 @@ def decode_one_batch(
                     model=model,
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
+                    use_max=params.use_max,
                 )
             else:
                 raise ValueError(
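In the `params.use_LG` branch above, the hypotheses coming back from fast_beam_search are word IDs (LG's output labels), so they are looked up in the word table directly rather than detokenized with the BPE model. A small self-contained sketch of that lookup (the table contents and IDs are made up):

    import k2

    # word_table behaves like a k2.SymbolTable: integer ID -> word string.
    word_table = k2.SymbolTable.from_str("<eps> 0\nHELLO 1\nWORLD 2")
    hyp = [1, 2]  # hypothetical word IDs from get_texts(best_path)
    print([word_table[i] for i in hyp])  # ['HELLO', 'WORLD']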
@@ -325,6 +387,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -338,6 +401,8 @@ def decode_dataset(
         The neural model.
       sp:
         The BPE model.
+      word_table:
+        The word symbol table.
       decoding_graph:
         The decoding graph. Can be either a `k2.trivial_graph` or an HLG.
         Used only when --decoding-method is fast_beam_search.
@@ -368,8 +433,9 @@ def decode_dataset(
             params=params,
             model=model,
             sp=sp,
-            decoding_graph=decoding_graph,
             batch=batch,
+            word_table=word_table,
+            decoding_graph=decoding_graph,
         )

         for name, hyps in hyps_dict.items():
@@ -460,13 +526,16 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-use-LG-{params.use_LG}"
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+        params.suffix += f"-use-max-{params.use_max}"
     elif "beam_search" in params.decoding_method:
         params.suffix += (
             f"-{params.decoding_method}-beam-size-{params.beam_size}"
         )
+        params.suffix += f"-use-max-{params.use_max}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -527,9 +596,21 @@ def main():
     model.device = device

     if params.decoding_method == "fast_beam_search":
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        if params.use_LG:
+            lexicon = Lexicon(params.lang_dir)
+            word_table = lexicon.word_table
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(f"{params.lang_dir}/LG.pt", map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            word_table = None
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
+        word_table = None

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
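Note that the LM weight is applied once, up front: multiplying `decoding_graph.scores` by `--ngram-lm-scale` scales the n-gram contribution of every path uniformly before the search runs. A toy illustration with a hand-written two-arc acceptor standing in for LG (the arc scores are made up):

    import k2

    # "src dst label score" per arc; the arc to the final state uses label -1.
    s = "0 1 10 -2.5\n1 2 -1 0.0\n2"
    G = k2.Fsa.from_str(s)
    print(G.scores)   # tensor([-2.5000,  0.0000])
    G.scores *= 0.01  # the default --ngram-lm-scale in this diff
    print(G.scores)   # tensor([-0.0250,  0.0000])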
@@ -551,6 +632,7 @@ def main():
         params=params,
         model=model,
         sp=sp,
+        word_table=word_table,
         decoding_graph=decoding_graph,
     )