add self loop to L

This commit is contained in:
Guo Liyong 2022-07-16 01:19:04 +08:00
parent 0a99ceb6ba
commit 473efcd531
3 changed files with 22 additions and 35 deletions

View File

@ -538,6 +538,7 @@ def greedy_search_batch(
encoder_out_lens: torch.Tensor,
decoding_graph: Optional[k2.Fsa] = None,
ngram_rescoring: bool = False,
gamma_blank: float = 1.0,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
@ -624,27 +625,12 @@ def greedy_search_batch(
logits_argmax = logits.argmax(dim=1)
logits_softmax = logits.softmax(dim=1)
# detailed in below fuction verify_non_blank_logits.
selection_verification = True
# 0 for blank frame and 1 for non-blank frame.
non_blank_flag[start:end] = torch.where(
logits_argmax == blank_id, 0, 1
logits_softmax[:, 0] >= gamma_blank, 0, 1
)
if False:
# In paper: https://arxiv.org/pdf/2101.06856.pdf
# A gama_blank threshold value is used to determinze blank frame.
# Currently, results are worse than baseline greedy_search
# and also very sensitive to gama_blank.
# (TODO): debug this later.
gama_blank = 0.50
non_blank_flag[start:end] = torch.where(
logits_softmax[:, 0] >= gama_blank, 0, 1
)
# function verify_non_blank_logits only works with logits_argmax == blank_id.
selection_verification = False
y = logits.argmax(dim=1).tolist()
emitted = False
@ -710,21 +696,10 @@ def greedy_search_batch(
all_logits_unpacked[i, :], 0, cur_non_blank_index
)
def verify_non_blank_logits():
# A way to verify non_blank_logits are selected correctly from all_logits.
hyps_before_rescore = non_blank_logits.argmax(dim=2)
for i in range(N):
usi = unsorted_indices[i]
hyp_to_verify = hyps_before_rescore[usi][
: int(non_blank_logits_lens[usi])
].tolist()
assert ans[i] == hyp_to_verify
logging.info("Verified non-blank logits.")
# TODO: skip verification after we finally get a workable rescoring method.
if selection_verification:
verify_non_blank_logits()
number_selected_frames = non_blank_flag.sum()
logging.info(f"{number_selected_frames} are selected out of {total_t} frames")
# Split log_softmax into two seperate steps,
# so we cound do blank deweight in probability domain if needed.
logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
@ -736,7 +711,7 @@ def greedy_search_batch(
# So just put this blank deweight before ngram rescoring.
# (TODO): debug this blank deweight issue.
blank_deweight = 100
blank_deweight = 0.0
logits_to_rescore[:, :, 0] -= blank_deweight
supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
@ -754,7 +729,7 @@ def greedy_search_batch(
subsampling_factor=1,
)
lm_weight = 0.3 # (TODO): tuning this.
lm_weight = 0.5 # (TODO): tuning this.
lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
best_path = one_best_decoding(
@ -762,7 +737,7 @@ def greedy_search_batch(
use_double_scores=True,
)
token_ids = get_alignments(best_path, "labels")
token_ids = get_alignments(best_path, "labels", remove_zero_blank=True)
ans = []
for i in range(N):

View File

@ -241,6 +241,12 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--gamma-blank",
type=int,
default=1.0,
)
return parser
@ -317,6 +323,7 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
decoding_graph=decoding_graph,
ngram_rescoring=params.ngram_rescoring,
gamma_blank=params.gamma_blank,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -525,6 +532,7 @@ def main():
if params.ngram_rescoring:
params.suffix += "-ngram-rescoring"
params.suffix += f"-{params.decoding_graph}"
params.suffix += f"-gamma_blank-{params.gamma_blank}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
@ -636,8 +644,7 @@ def main():
if params.ngram_rescoring and params.decoding_method == "greedy_search":
assert params.decoding_graph in [
"trivial_graph",
"HLG",
"Trivial_LG",
"L",
], f"Unsupported decoding graph {params.decoding_graph}"
if params.decoding_graph == "trivial_graph":
decoding_graph = k2.trivial_graph(
@ -650,6 +657,7 @@ def main():
map_location=device,
)
)
decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
decoding_graph.lm_scores = decoding_graph.scores.clone()

View File

@ -236,7 +236,9 @@ def get_texts(
return aux_labels.tolist()
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
def get_alignments(
best_paths: k2.Fsa, kind: str, remove_zero_blank: bool = False
) -> List[List[int]]:
"""Extract labels or aux_labels from the best-path FSAs.
Args:
@ -272,6 +274,8 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
token_shape, getattr(best_paths, kind).contiguous()
)
tokens = tokens.remove_values_eq(-1)
if remove_zero_blank:
tokens = tokens.remove_values_eq(0)
return tokens.tolist()