add self loop to L

This commit is contained in:
Guo Liyong 2022-07-16 01:19:04 +08:00
parent 0a99ceb6ba
commit 473efcd531
3 changed files with 22 additions and 35 deletions

View File

@ -538,6 +538,7 @@ def greedy_search_batch(
encoder_out_lens: torch.Tensor, encoder_out_lens: torch.Tensor,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
ngram_rescoring: bool = False, ngram_rescoring: bool = False,
gamma_blank: float = 1.0,
) -> List[List[int]]: ) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args: Args:
@ -624,27 +625,12 @@ def greedy_search_batch(
logits_argmax = logits.argmax(dim=1) logits_argmax = logits.argmax(dim=1)
logits_softmax = logits.softmax(dim=1) logits_softmax = logits.softmax(dim=1)
# detailed in below fuction verify_non_blank_logits.
selection_verification = True
# 0 for blank frame and 1 for non-blank frame. # 0 for blank frame and 1 for non-blank frame.
non_blank_flag[start:end] = torch.where( non_blank_flag[start:end] = torch.where(
logits_argmax == blank_id, 0, 1 logits_softmax[:, 0] >= gamma_blank, 0, 1
) )
if False:
# In paper: https://arxiv.org/pdf/2101.06856.pdf
# A gama_blank threshold value is used to determinze blank frame.
# Currently, results are worse than baseline greedy_search
# and also very sensitive to gama_blank.
# (TODO): debug this later.
gama_blank = 0.50
non_blank_flag[start:end] = torch.where(
logits_softmax[:, 0] >= gama_blank, 0, 1
)
# function verify_non_blank_logits only works with logits_argmax == blank_id.
selection_verification = False
y = logits.argmax(dim=1).tolist() y = logits.argmax(dim=1).tolist()
emitted = False emitted = False
@ -710,21 +696,10 @@ def greedy_search_batch(
all_logits_unpacked[i, :], 0, cur_non_blank_index all_logits_unpacked[i, :], 0, cur_non_blank_index
) )
def verify_non_blank_logits():
# A way to verify non_blank_logits are selected correctly from all_logits.
hyps_before_rescore = non_blank_logits.argmax(dim=2)
for i in range(N):
usi = unsorted_indices[i]
hyp_to_verify = hyps_before_rescore[usi][
: int(non_blank_logits_lens[usi])
].tolist()
assert ans[i] == hyp_to_verify
logging.info("Verified non-blank logits.")
# TODO: skip verification after we finally get a workable rescoring method.
if selection_verification:
verify_non_blank_logits()
number_selected_frames = non_blank_flag.sum()
logging.info(f"{number_selected_frames} are selected out of {total_t} frames")
# Split log_softmax into two seperate steps, # Split log_softmax into two seperate steps,
# so we cound do blank deweight in probability domain if needed. # so we cound do blank deweight in probability domain if needed.
logits_to_rescore_softmax = non_blank_logits.softmax(dim=2) logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
@ -736,7 +711,7 @@ def greedy_search_batch(
# So just put this blank deweight before ngram rescoring. # So just put this blank deweight before ngram rescoring.
# (TODO): debug this blank deweight issue. # (TODO): debug this blank deweight issue.
blank_deweight = 100 blank_deweight = 0.0
logits_to_rescore[:, :, 0] -= blank_deweight logits_to_rescore[:, :, 0] -= blank_deweight
supervision_segments = torch.zeros([N, 3], dtype=torch.int32) supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
@ -754,7 +729,7 @@ def greedy_search_batch(
subsampling_factor=1, subsampling_factor=1,
) )
lm_weight = 0.3 # (TODO): tuning this. lm_weight = 0.5 # (TODO): tuning this.
lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight) lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
best_path = one_best_decoding( best_path = one_best_decoding(
@ -762,7 +737,7 @@ def greedy_search_batch(
use_double_scores=True, use_double_scores=True,
) )
token_ids = get_alignments(best_path, "labels") token_ids = get_alignments(best_path, "labels", remove_zero_blank=True)
ans = [] ans = []
for i in range(N): for i in range(N):

View File

@ -241,6 +241,12 @@ def get_parser():
Used only when --decoding_method is greedy_search""", Used only when --decoding_method is greedy_search""",
) )
parser.add_argument(
"--gamma-blank",
type=int,
default=1.0,
)
return parser return parser
@ -317,6 +323,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
ngram_rescoring=params.ngram_rescoring, ngram_rescoring=params.ngram_rescoring,
gamma_blank=params.gamma_blank,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -525,6 +532,7 @@ def main():
if params.ngram_rescoring: if params.ngram_rescoring:
params.suffix += "-ngram-rescoring" params.suffix += "-ngram-rescoring"
params.suffix += f"-{params.decoding_graph}" params.suffix += f"-{params.decoding_graph}"
params.suffix += f"-gamma_blank-{params.gamma_blank}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started") logging.info("Decoding started")
@ -636,8 +644,7 @@ def main():
if params.ngram_rescoring and params.decoding_method == "greedy_search": if params.ngram_rescoring and params.decoding_method == "greedy_search":
assert params.decoding_graph in [ assert params.decoding_graph in [
"trivial_graph", "trivial_graph",
"HLG", "L",
"Trivial_LG",
], f"Unsupported decoding graph {params.decoding_graph}" ], f"Unsupported decoding graph {params.decoding_graph}"
if params.decoding_graph == "trivial_graph": if params.decoding_graph == "trivial_graph":
decoding_graph = k2.trivial_graph( decoding_graph = k2.trivial_graph(
@ -650,6 +657,7 @@ def main():
map_location=device, map_location=device,
) )
) )
decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
decoding_graph.lm_scores = decoding_graph.scores.clone() decoding_graph.lm_scores = decoding_graph.scores.clone()

View File

@ -236,7 +236,9 @@ def get_texts(
return aux_labels.tolist() return aux_labels.tolist()
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: def get_alignments(
best_paths: k2.Fsa, kind: str, remove_zero_blank: bool = False
) -> List[List[int]]:
"""Extract labels or aux_labels from the best-path FSAs. """Extract labels or aux_labels from the best-path FSAs.
Args: Args:
@ -272,6 +274,8 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
token_shape, getattr(best_paths, kind).contiguous() token_shape, getattr(best_paths, kind).contiguous()
) )
tokens = tokens.remove_values_eq(-1) tokens = tokens.remove_values_eq(-1)
if remove_zero_blank:
tokens = tokens.remove_values_eq(0)
return tokens.tolist() return tokens.tolist()