psd algorithm

This commit is contained in:
Guo Liyong 2022-07-14 00:01:28 +08:00
parent bc2882ddcc
commit 0a99ceb6ba
3 changed files with 219 additions and 9 deletions

View File

@ -47,10 +47,19 @@ def get_args():
""",
)
parser.add_argument(
"--h-graph",
type=str,
help="""one of ["H", "Trivial"]
H: k2.ctc_topo
Trivial: k2.trivial_graph
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str) -> k2.Fsa:
def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa:
"""
Args:
lang_dir:
@ -62,7 +71,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
if h_graph == "H":
H = k2.ctc_topo(max_token_id)
elif h_graph == "Trivial":
H = k2.trivial_graph(max_token_id - 1)
else:
raise ValueError(f"Unsupported h_graph: {h_graph}")
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
@ -138,15 +154,17 @@ def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
if (lang_dir / f"{args.h_graph}LG.pt").is_file():
logging.info(
f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping"
)
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt")
if __name__ == "__main__":

View File

@ -14,16 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
import k2
import torch
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
from model import Transducer
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
from icefall.decode import get_lattice, Nbest, one_best_decoding
from icefall.utils import get_alignments, get_texts
def fast_beam_search_one_best(
@ -534,6 +536,8 @@ def greedy_search_batch(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
decoding_graph: Optional[k2.Fsa] = None,
ngram_rescoring: bool = False,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
@ -583,6 +587,18 @@ def greedy_search_batch(
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
if ngram_rescoring:
vocab_size = model.decoder.vocab_size
total_t = encoder_out.shape[0]
# cached all joiner outputs during greedy search,
# from which non-blank frames are selected before n-gram rescoring.
all_logits = torch.zeros([total_t, vocab_size], device=device)
# A flag indicating a frame is a blank frame or not.
# 0 for blank frame and 1 for non-blank frame.
# Used to select non-blank frames for n-gram rescoring.
non_blank_flag = torch.zeros([total_t], device=device)
offset = 0
for batch_size in batch_size_list:
start = offset
@ -600,7 +616,36 @@ def greedy_search_batch(
# logits'shape (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
if ngram_rescoring:
all_logits[start:end] = logits
assert logits.ndim == 2, logits.shape
logits_argmax = logits.argmax(dim=1)
logits_softmax = logits.softmax(dim=1)
# detailed in below fuction verify_non_blank_logits.
selection_verification = True
# 0 for blank frame and 1 for non-blank frame.
non_blank_flag[start:end] = torch.where(
logits_argmax == blank_id, 0, 1
)
if False:
# In paper: https://arxiv.org/pdf/2101.06856.pdf
# A gama_blank threshold value is used to determinze blank frame.
# Currently, results are worse than baseline greedy_search
# and also very sensitive to gama_blank.
# (TODO): debug this later.
gama_blank = 0.50
non_blank_flag[start:end] = torch.where(
logits_softmax[:, 0] >= gama_blank, 0, 1
)
# function verify_non_blank_logits only works with logits_argmax == blank_id.
selection_verification = False
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
@ -624,6 +669,105 @@ def greedy_search_batch(
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
if not ngram_rescoring:
return ans
assert decoding_graph is not None
# Transform logits to shape [N, T, vocab_size] format to make it easier
# to select non-blank frames.
packed_all_logits = PackedSequence(
all_logits, torch.tensor(batch_size_list)
)
all_logits_unpacked, _ = pad_packed_sequence(
packed_all_logits, batch_first=True
)
# Transform non_blank_flag to shape [N, T]
packed_non_blank_flag = PackedSequence(
non_blank_flag, torch.tensor(batch_size_list)
)
non_blank_flag_unpacked, _ = pad_packed_sequence(
packed_non_blank_flag, batch_first=True
)
non_blank_logits_lens = torch.sum(non_blank_flag_unpacked, dim=1)
max_frame_to_rescore = non_blank_logits_lens.max()
non_blank_logits = torch.zeros(
[N, int(max_frame_to_rescore), vocab_size], device=device
)
# torch.index_select only acceptec a single dimension to index from.
# So we need generate non_blank_logits one by one.
# Maybe there is another efficient way to do this.
for i in range(N):
cur_non_blank_index = torch.where(non_blank_flag_unpacked[i, :] != 0)[0]
assert non_blank_logits_lens[i] == cur_non_blank_index.shape[0]
non_blank_logits[
i, : int(non_blank_logits_lens[i]), :
] = torch.index_select(
all_logits_unpacked[i, :], 0, cur_non_blank_index
)
def verify_non_blank_logits():
# A way to verify non_blank_logits are selected correctly from all_logits.
hyps_before_rescore = non_blank_logits.argmax(dim=2)
for i in range(N):
usi = unsorted_indices[i]
hyp_to_verify = hyps_before_rescore[usi][
: int(non_blank_logits_lens[usi])
].tolist()
assert ans[i] == hyp_to_verify
logging.info("Verified non-blank logits.")
# TODO: skip verification after we finally get a workable rescoring method.
if selection_verification:
verify_non_blank_logits()
# Split log_softmax into two seperate steps,
# so we cound do blank deweight in probability domain if needed.
logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
logits_to_rescore = logits_to_rescore_softmax.log()
# In paper: https://arxiv.org/pdf/2101.06856.pdf
# blank deweight is applied before non_blank frames selected.
# However, in current setup, that results in a higher WER.
# So just put this blank deweight before ngram rescoring.
# (TODO): debug this blank deweight issue.
blank_deweight = 100
logits_to_rescore[:, :, 0] -= blank_deweight
supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32)
supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32)
lattice = get_lattice(
nnet_output=logits_to_rescore,
decoding_graph=decoding_graph,
supervision_segments=supervision_segments,
search_beam=20,
output_beam=8,
min_active_states=30,
max_active_states=1000,
subsampling_factor=1,
)
lm_weight = 0.3 # (TODO): tuning this.
lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
best_path = one_best_decoding(
lattice=lattice,
use_double_scores=True,
)
token_ids = get_alignments(best_path, "labels")
ans = []
for i in range(N):
usi = unsorted_indices[i]
ans.append(token_ids[usi][: int(non_blank_logits_lens[usi])])
return ans

View File

@ -136,6 +136,28 @@ def get_parser():
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--ngram-rescoring",
type=str2bool,
default=False,
help="Whether to use ngram_rescoring.",
)
parser.add_argument(
"--decoding-graph",
type=str,
default="trivial_graph",
help="one of [trivial_grpah, HLG, Trival_LG, LG]"
"used by greedy_search_batch with ngram-rescoring=True.",
)
parser.add_argument(
"--lang-dir",
type=str,
default="./data/lang_bpe_500/",
help="Path to decoding graphs",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -293,6 +315,8 @@ def decode_one_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
decoding_graph=decoding_graph,
ngram_rescoring=params.ngram_rescoring,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -498,6 +522,10 @@ def main():
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
if params.ngram_rescoring:
params.suffix += "-ngram-rescoring"
params.suffix += f"-{params.decoding_graph}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -605,6 +633,26 @@ def main():
else:
decoding_graph = None
if params.ngram_rescoring and params.decoding_method == "greedy_search":
assert params.decoding_graph in [
"trivial_graph",
"HLG",
"Trivial_LG",
], f"Unsupported decoding graph {params.decoding_graph}"
if params.decoding_graph == "trivial_graph":
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = k2.Fsa.from_dict(
torch.load(
f"data/lang_bpe_500/{params.decoding_graph}.pt",
map_location=device,
)
)
decoding_graph.lm_scores = decoding_graph.scores.clone()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")