psd algorithm

Guo Liyong 2022-07-14 00:01:28 +08:00
parent bc2882ddcc
commit 0a99ceb6ba
3 changed files with 219 additions and 9 deletions

View File

@@ -47,10 +47,19 @@ def get_args():
         """,
     )

+    parser.add_argument(
+        "--h-graph",
+        type=str,
+        help="""one of ["H", "Trivial"]
+        H: k2.ctc_topo
+        Trivial: k2.trivial_graph
+        """,
+    )
+
     return parser.parse_args()


-def compile_HLG(lang_dir: str) -> k2.Fsa:
+def compile_HLG(lang_dir: str, h_graph: str = "H") -> k2.Fsa:
     """
     Args:
       lang_dir:
@@ -62,7 +71,14 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     lexicon = Lexicon(lang_dir)
     max_token_id = max(lexicon.tokens)
     logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
-    H = k2.ctc_topo(max_token_id)
+
+    if h_graph == "H":
+        H = k2.ctc_topo(max_token_id)
+    elif h_graph == "Trivial":
+        H = k2.trivial_graph(max_token_id - 1)
+    else:
+        raise ValueError(f"Unsupported h_graph: {h_graph}")
+
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

     if Path("data/lm/G_3_gram.pt").is_file():
@@ -138,15 +154,17 @@ def main():
     args = get_args()
     lang_dir = Path(args.lang_dir)

-    if (lang_dir / "HLG.pt").is_file():
-        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+    if (lang_dir / f"{args.h_graph}LG.pt").is_file():
+        logging.info(
+            f"{lang_dir}/{args.h_graph}LG.pt already exists - skipping"
+        )
         return

     logging.info(f"Processing {lang_dir}")

     HLG = compile_HLG(lang_dir)
-    logging.info(f"Saving HLG.pt to {lang_dir}")
-    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
+    logging.info(f"Saving {args.h_graph}LG.pt to {lang_dir}")
+    torch.save(HLG.as_dict(), f"{lang_dir}/{args.h_graph}LG.pt")


 if __name__ == "__main__":
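The new --h-graph switch chooses between the two topologies compared below. A minimal standalone sketch, assuming k2 is installed and using a hypothetical max_token_id; it mirrors the branch added to compile_HLG above:

    import k2

    max_token_id = 500  # hypothetical; the real script takes it from the lexicon

    # "H": CTC topology with a blank symbol and self-loops for repeated tokens.
    H_ctc = k2.ctc_topo(max_token_id)

    # "Trivial": a trivial acceptor that allows any token sequence
    # (the "max_token_id - 1" argument mirrors the diff above).
    H_trivial = k2.trivial_graph(max_token_id - 1)

    print("ctc_topo:", H_ctc.shape)
    print("trivial_graph:", H_trivial.shape)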

View File

@@ -14,16 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 import warnings
 from dataclasses import dataclass
 from typing import Dict, List, Optional

 import k2
 import torch
+from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence

 from model import Transducer

-from icefall.decode import Nbest, one_best_decoding
-from icefall.utils import get_texts
+from icefall.decode import get_lattice, Nbest, one_best_decoding
+from icefall.utils import get_alignments, get_texts


 def fast_beam_search_one_best(
@@ -534,6 +536,8 @@ def greedy_search_batch(
     model: Transducer,
     encoder_out: torch.Tensor,
     encoder_out_lens: torch.Tensor,
+    decoding_graph: Optional[k2.Fsa] = None,
+    ngram_rescoring: bool = False,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

     Args:
@@ -583,6 +587,18 @@ def greedy_search_batch(
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

+    if ngram_rescoring:
+        vocab_size = model.decoder.vocab_size
+        total_t = encoder_out.shape[0]
+        # Cache all joiner outputs during greedy search,
+        # from which non-blank frames are selected before n-gram rescoring.
+        all_logits = torch.zeros([total_t, vocab_size], device=device)
+        # A flag indicating whether a frame is a blank frame or not:
+        # 0 for a blank frame and 1 for a non-blank frame.
+        # Used to select non-blank frames for n-gram rescoring.
+        non_blank_flag = torch.zeros([total_t], device=device)
+
     offset = 0
     for batch_size in batch_size_list:
         start = offset
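Both caches above follow the layout of encoder_out, which at this point is the data field of a PackedSequence: frames are stored time-major over the sorted batch, each start:end slice covers one frame of every still-active utterance, and total_t equals sum(batch_size_list). A toy sketch of that layout (shapes and names are made up, not from the commit):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    x = torch.arange(6, dtype=torch.float32).view(2, 3, 1)  # (N=2, T=3, C=1)
    lens = torch.tensor([3, 2])                              # per-utterance lengths

    packed = pack_padded_sequence(x, lens, batch_first=True)
    print(packed.batch_sizes.tolist())  # [2, 2, 1]; their sum is total_t == 5
    print(packed.data.shape)            # torch.Size([5, 1]): time-major, sorted batch

    # pad_packed_sequence undoes the packing, which is how the cached logits
    # are brought back to (N, T, ...) before selecting non-blank frames.
    unpacked, out_lens = pad_packed_sequence(packed, batch_first=True)
    print(unpacked.shape, out_lens.tolist())  # torch.Size([2, 3, 1]) [3, 2]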
@@ -600,7 +616,36 @@ def greedy_search_batch(
         # logits'shape (batch_size, 1, 1, vocab_size)

         logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
-        assert logits.ndim == 2, logits.shape
+
+        if ngram_rescoring:
+            all_logits[start:end] = logits
+            assert logits.ndim == 2, logits.shape
+            logits_argmax = logits.argmax(dim=1)
+            logits_softmax = logits.softmax(dim=1)
+
+            # Detailed in the function verify_non_blank_logits below.
+            selection_verification = True
+
+            # 0 for a blank frame and 1 for a non-blank frame.
+            non_blank_flag[start:end] = torch.where(
+                logits_argmax == blank_id, 0, 1
+            )
+
+            if False:
+                # In the paper https://arxiv.org/pdf/2101.06856.pdf,
+                # a gamma_blank threshold is used to determine blank frames.
+                # Currently, results are worse than the baseline greedy_search
+                # and are also very sensitive to gamma_blank.
+                # TODO: debug this later.
+                gamma_blank = 0.50
+                non_blank_flag[start:end] = torch.where(
+                    logits_softmax[:, 0] >= gamma_blank, 0, 1
+                )
+                # The function verify_non_blank_logits only works with
+                # logits_argmax == blank_id.
+                selection_verification = False
+
         y = logits.argmax(dim=1).tolist()
         emitted = False
         for i, v in enumerate(y):
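The active criterion (argmax equals blank_id) and the disabled gamma_blank threshold from https://arxiv.org/pdf/2101.06856.pdf can disagree on frames where blank wins the argmax but has a low posterior, which is one reason the two variants behave differently. A toy sketch with made-up logits, blank_id = 0, and the same 0/1 flag convention as above:

    import torch

    blank_id = 0
    gamma_blank = 0.50

    logits = torch.tensor(
        [
            [4.0, 1.0, 0.5],  # clearly blank
            [0.2, 3.0, 0.1],  # clearly non-blank
            [1.0, 0.9, 0.8],  # argmax is blank, but blank posterior is only ~0.37
        ]
    )

    # 0 for a blank frame, 1 for a non-blank frame.
    by_argmax = torch.where(logits.argmax(dim=1) == blank_id, 0, 1)
    by_threshold = torch.where(
        logits.softmax(dim=1)[:, blank_id] >= gamma_blank, 0, 1
    )

    print(by_argmax.tolist())     # [0, 1, 0]
    print(by_threshold.tolist())  # [0, 1, 1]; the borderline frame is kept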
@@ -624,6 +669,105 @@ def greedy_search_batch(
     for i in range(N):
         ans.append(sorted_ans[unsorted_indices[i]])

+    if not ngram_rescoring:
+        return ans
+
+    assert decoding_graph is not None
+
+    # Transform logits to shape [N, T, vocab_size] to make it easier
+    # to select non-blank frames.
+    packed_all_logits = PackedSequence(
+        all_logits, torch.tensor(batch_size_list)
+    )
+    all_logits_unpacked, _ = pad_packed_sequence(
+        packed_all_logits, batch_first=True
+    )
+
+    # Transform non_blank_flag to shape [N, T].
+    packed_non_blank_flag = PackedSequence(
+        non_blank_flag, torch.tensor(batch_size_list)
+    )
+    non_blank_flag_unpacked, _ = pad_packed_sequence(
+        packed_non_blank_flag, batch_first=True
+    )
+
+    non_blank_logits_lens = torch.sum(non_blank_flag_unpacked, dim=1)
+    max_frame_to_rescore = non_blank_logits_lens.max()
+    non_blank_logits = torch.zeros(
+        [N, int(max_frame_to_rescore), vocab_size], device=device
+    )
+
+    # torch.index_select only accepts a single dimension to index from,
+    # so we generate non_blank_logits one utterance at a time.
+    # Maybe there is a more efficient way to do this.
+    for i in range(N):
+        cur_non_blank_index = torch.where(non_blank_flag_unpacked[i, :] != 0)[0]
+        assert non_blank_logits_lens[i] == cur_non_blank_index.shape[0]
+        non_blank_logits[
+            i, : int(non_blank_logits_lens[i]), :
+        ] = torch.index_select(
+            all_logits_unpacked[i, :], 0, cur_non_blank_index
+        )
+
+    def verify_non_blank_logits():
+        # A way to verify that non_blank_logits is selected correctly
+        # from all_logits.
+        hyps_before_rescore = non_blank_logits.argmax(dim=2)
+        for i in range(N):
+            usi = unsorted_indices[i]
+            hyp_to_verify = hyps_before_rescore[usi][
+                : int(non_blank_logits_lens[usi])
+            ].tolist()
+            assert ans[i] == hyp_to_verify
+        logging.info("Verified non-blank logits.")
+
+    # TODO: skip verification after we finally get a workable rescoring method.
+    if selection_verification:
+        verify_non_blank_logits()
+
+    # Split log_softmax into two separate steps,
+    # so we could do blank deweighting in the probability domain if needed.
+    logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
+    logits_to_rescore = logits_to_rescore_softmax.log()
+
+    # In the paper https://arxiv.org/pdf/2101.06856.pdf,
+    # blank deweighting is applied before the non-blank frames are selected.
+    # However, in the current setup that results in a higher WER,
+    # so the blank deweighting is done just before the n-gram rescoring.
+    # TODO: debug this blank deweighting issue.
+    blank_deweight = 100
+    logits_to_rescore[:, :, 0] -= blank_deweight
+
+    supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
+    supervision_segments[:, 0] = torch.arange(0, N, dtype=torch.int32)
+    supervision_segments[:, 2] = non_blank_logits_lens.to(torch.int32)
+
+    lattice = get_lattice(
+        nnet_output=logits_to_rescore,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=20,
+        output_beam=8,
+        min_active_states=30,
+        max_active_states=1000,
+        subsampling_factor=1,
+    )
+
+    lm_weight = 0.3  # TODO: tune this.
+    lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
+
+    best_path = one_best_decoding(
+        lattice=lattice,
+        use_double_scores=True,
+    )
+    token_ids = get_alignments(best_path, "labels")
+
+    ans = []
+    for i in range(N):
+        usi = unsorted_indices[i]
+        ans.append(token_ids[usi][: int(non_blank_logits_lens[usi])])
+
     return ans
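The comment above notes that the per-utterance index_select loop may not be the most efficient way to gather non-blank frames. One possible vectorized alternative, sketched here with toy shapes and hypothetical names (not part of the commit), masks the padded logits and writes the kept frames left-aligned:

    import torch

    N, T, V = 2, 4, 3
    all_logits_unpacked = torch.randn(N, T, V)
    non_blank_flag_unpacked = torch.tensor(
        [
            [1.0, 0.0, 1.0, 0.0],
            [1.0, 1.0, 0.0, 0.0],
        ]
    )

    lens = non_blank_flag_unpacked.sum(dim=1).long()  # kept frames per utterance
    max_len = int(lens.max())

    keep = non_blank_flag_unpacked.bool()                          # (N, T)
    dest = torch.arange(max_len).unsqueeze(0) < lens.unsqueeze(1)  # (N, max_len)

    non_blank_logits = torch.zeros(N, max_len, V)
    # Boolean indexing is row-major, so per-utterance frame order is preserved
    # and each utterance's kept frames land left-aligned in its row.
    non_blank_logits[dest] = all_logits_unpacked[keep]

    print(non_blank_logits.shape)  # torch.Size([2, 2, 3])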

View File

@@ -136,6 +136,28 @@ def get_parser():
         "`epoch` are loaded for averaging. ",
     )

+    parser.add_argument(
+        "--ngram-rescoring",
+        type=str2bool,
+        default=False,
+        help="Whether to use ngram_rescoring.",
+    )
+
+    parser.add_argument(
+        "--decoding-graph",
+        type=str,
+        default="trivial_graph",
+        help="One of [trivial_graph, HLG, Trivial_LG, LG]. "
+        "Used by greedy_search_batch with ngram-rescoring=True.",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="./data/lang_bpe_500/",
+        help="Path to decoding graphs",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
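A minimal sketch of how the three new options parse; the str2bool helper below is a simplified stand-in for the one imported from icefall.utils in the real script:

    import argparse


    def str2bool(v: str) -> bool:
        # Simplified stand-in for icefall.utils.str2bool.
        return v.lower() in ("yes", "true", "t", "1")


    parser = argparse.ArgumentParser()
    parser.add_argument("--ngram-rescoring", type=str2bool, default=False)
    parser.add_argument("--decoding-graph", type=str, default="trivial_graph")
    parser.add_argument("--lang-dir", type=str, default="./data/lang_bpe_500/")

    args = parser.parse_args(["--ngram-rescoring", "true", "--decoding-graph", "HLG"])
    print(args.ngram_rescoring, args.decoding_graph, args.lang_dir)
    # True HLG ./data/lang_bpe_500/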
@@ -293,6 +315,8 @@ def decode_one_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            decoding_graph=decoding_graph,
+            ngram_rescoring=params.ngram_rescoring,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -498,6 +522,10 @@ def main():
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"

+    if params.ngram_rescoring:
+        params.suffix += "-ngram-rescoring"
+        params.suffix += f"-{params.decoding_graph}"
+
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -605,6 +633,26 @@ def main():
     else:
         decoding_graph = None

+    if params.ngram_rescoring and params.decoding_method == "greedy_search":
+        assert params.decoding_graph in [
+            "trivial_graph",
+            "HLG",
+            "Trivial_LG",
+        ], f"Unsupported decoding graph {params.decoding_graph}"
+        if params.decoding_graph == "trivial_graph":
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
+        else:
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(
+                    f"data/lang_bpe_500/{params.decoding_graph}.pt",
+                    map_location=device,
+                )
+            )
+        decoding_graph.lm_scores = decoding_graph.scores.clone()
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")