mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

commit 473efcd531: add self loop to L
parent: 0a99ceb6ba
@@ -538,6 +538,7 @@ def greedy_search_batch(
     encoder_out_lens: torch.Tensor,
     decoding_graph: Optional[k2.Fsa] = None,
     ngram_rescoring: bool = False,
+    gamma_blank: float = 1.0,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
@@ -624,27 +625,12 @@ def greedy_search_batch(
         logits_argmax = logits.argmax(dim=1)
         logits_softmax = logits.softmax(dim=1)
 
-        # detailed in the function verify_non_blank_logits below.
-        selection_verification = True
 
         # 0 for blank frame and 1 for non-blank frame.
         non_blank_flag[start:end] = torch.where(
-            logits_argmax == blank_id, 0, 1
+            logits_softmax[:, 0] >= gamma_blank, 0, 1
         )
 
-        if False:
-            # In paper: https://arxiv.org/pdf/2101.06856.pdf
-            # A gama_blank threshold value is used to determine blank frames.
-            # Currently, results are worse than baseline greedy_search
-            # and also very sensitive to gama_blank.
-            # (TODO): debug this later.
-            gama_blank = 0.50
-            non_blank_flag[start:end] = torch.where(
-                logits_softmax[:, 0] >= gama_blank, 0, 1
-            )
-
-            # function verify_non_blank_logits only works with logits_argmax == blank_id.
-            selection_verification = False
 
         y = logits.argmax(dim=1).tolist()
         emitted = False
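For context, the hunk above replaces the argmax-based blank test with a threshold on the blank posterior, following the frame-skipping idea of arXiv:2101.06856. Below is a minimal standalone sketch (not part of the commit) of that selection rule; the function name and shapes are illustrative, assuming logits of shape (num_frames, vocab_size) with blank at index 0.

import torch

def select_non_blank_frames(logits: torch.Tensor, gamma_blank: float = 1.0) -> torch.Tensor:
    # Blank posterior per frame; blank is assumed to be token id 0.
    blank_posterior = logits.softmax(dim=1)[:, 0]
    # 0 marks a blank frame (posterior >= threshold), 1 marks a non-blank frame.
    return torch.where(blank_posterior >= gamma_blank, 0, 1)

With the default gamma_blank = 1.0 the threshold is effectively never reached, so every frame is flagged as non-blank and nothing is skipped.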
@@ -710,21 +696,10 @@ def greedy_search_batch(
             all_logits_unpacked[i, :], 0, cur_non_blank_index
         )
 
-    def verify_non_blank_logits():
-        # A way to verify non_blank_logits are selected correctly from all_logits.
-        hyps_before_rescore = non_blank_logits.argmax(dim=2)
-        for i in range(N):
-            usi = unsorted_indices[i]
-            hyp_to_verify = hyps_before_rescore[usi][
-                : int(non_blank_logits_lens[usi])
-            ].tolist()
-            assert ans[i] == hyp_to_verify
-        logging.info("Verified non-blank logits.")
 
-    # TODO: skip verification after we finally get a workable rescoring method.
-    if selection_verification:
-        verify_non_blank_logits()
 
+    number_selected_frames = non_blank_flag.sum()
+    logging.info(f"{number_selected_frames} are selected out of {total_t} frames")
     # Split log_softmax into two separate steps,
     # so we could do blank deweight in probability domain if needed.
     logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
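The surrounding code gathers only the flagged frames into non_blank_logits with torch.index_select; the deleted verify_non_blank_logits helper merely checked that this gathering preserved the greedy hypotheses. A small standalone sketch of the gathering step (variable names are illustrative, not the commit's):

import torch

# logits: (num_frames, vocab_size); flag: (num_frames,) with 1 for frames to keep.
logits = torch.randn(6, 500)
flag = torch.tensor([1, 0, 1, 1, 0, 1])

# Indices of the frames flagged as non-blank, then gather them along dim 0.
non_blank_index = flag.nonzero(as_tuple=False).squeeze(1)
non_blank_logits = torch.index_select(logits, 0, non_blank_index)
assert non_blank_logits.shape[0] == int(flag.sum())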
@@ -736,7 +711,7 @@ def greedy_search_batch(
     # So just put this blank deweight before ngram rescoring.
     # (TODO): debug this blank deweight issue.
 
-    blank_deweight = 100
+    blank_deweight = 0.0
     logits_to_rescore[:, :, 0] -= blank_deweight
 
     supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
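Blank deweighting here just subtracts a constant from the blank column of the logits before rescoring; subtracting c from one logit scales that token's unnormalized probability by exp(-c), so its posterior drops relative to every other token (with blank_deweight = 0.0 it is a no-op). A tiny illustrative sketch, assuming blank is token 0 and the shapes are made up:

import torch

logits_to_rescore = torch.randn(2, 5, 500)  # (N, T_selected, vocab_size)
blank_deweight = 2.0  # any positive value pushes probability mass away from blank
logits_to_rescore[:, :, 0] -= blank_deweight
# After softmax/log_softmax, the blank token's odds against every other token
# are exactly exp(-blank_deweight) times smaller than before the subtraction.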
@@ -754,7 +729,7 @@ def greedy_search_batch(
         subsampling_factor=1,
     )
 
-    lm_weight = 0.3  # (TODO): tuning this.
+    lm_weight = 0.5  # (TODO): tuning this.
    lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
 
     best_path = one_best_decoding(
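The interpolation line relies on lattice.scores being the sum of acoustic and LM scores: subtracting (1 - lm_weight) * lm_scores leaves am_scores + lm_weight * lm_scores. A small check of that identity with plain tensors (illustrative values, not k2 lattices):

import torch

am_scores = torch.tensor([1.0, -2.0, 0.5])
lm_scores = torch.tensor([-0.3, -1.0, -0.2])
scores = am_scores + lm_scores

lm_weight = 0.5
rescored = scores - lm_scores * (1 - lm_weight)
assert torch.allclose(rescored, am_scores + lm_weight * lm_scores)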
@@ -762,7 +737,7 @@ def greedy_search_batch(
         use_double_scores=True,
     )
 
-    token_ids = get_alignments(best_path, "labels")
+    token_ids = get_alignments(best_path, "labels", remove_zero_blank=True)
 
     ans = []
     for i in range(N):
@@ -241,6 +241,12 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--gamma-blank",
+        type=int,
+        default=1.0,
+    )
+
     return parser
 
 
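For reference, a standalone sketch of exposing the blank threshold on the command line; greedy_search_batch declares gamma_blank as a float with default 1.0, so the sketch parses it as a float. The parser, help text, and example value below are purely illustrative and not part of the commit.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--gamma-blank",
    type=float,
    default=1.0,
    help="Blank-posterior threshold; frames with P(blank) >= this value are skipped.",
)
args = parser.parse_args(["--gamma-blank", "0.95"])
assert args.gamma_blank == 0.95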
@@ -317,6 +323,7 @@ def decode_one_batch(
             encoder_out_lens=encoder_out_lens,
             decoding_graph=decoding_graph,
             ngram_rescoring=params.ngram_rescoring,
+            gamma_blank=params.gamma_blank,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -525,6 +532,7 @@ def main():
     if params.ngram_rescoring:
         params.suffix += "-ngram-rescoring"
         params.suffix += f"-{params.decoding_graph}"
+        params.suffix += f"-gamma_blank-{params.gamma_blank}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
@@ -636,8 +644,7 @@ def main():
     if params.ngram_rescoring and params.decoding_method == "greedy_search":
         assert params.decoding_graph in [
             "trivial_graph",
             "HLG",
             "Trivial_LG",
+            "L",
         ], f"Unsupported decoding graph {params.decoding_graph}"
         if params.decoding_graph == "trivial_graph":
             decoding_graph = k2.trivial_graph(
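Among the supported graphs, trivial_graph is the no-lexicon baseline: an acceptor with one self-loop per token id, so it accepts any token sequence and contributes no language-model information. A minimal sketch of building it (the vocabulary size is illustrative; in practice it comes from the BPE model):

import k2

vocab_size = 500
decoding_graph = k2.trivial_graph(vocab_size - 1, device="cpu")
# A single looping state with one self-loop for each token id 1..vocab_size-1,
# plus a final arc, so intersecting with it leaves the posteriors unchanged.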
@@ -650,6 +657,7 @@ def main():
                     map_location=device,
                 )
             )
+            decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
 
             decoding_graph.lm_scores = decoding_graph.scores.clone()
 
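k2.add_epsilon_self_loops is what gives the commit its name: it adds label-0 (epsilon/blank) self-loops to the graph's states so that blank (token 0) emissions can be matched during intersection with the frame posteriors. The commit applies it to the decoding graph loaded via k2.Fsa.from_dict; the toy acceptor below is only an illustration of the call.

import k2

# A toy acceptor in k2's string format: arcs are "src dst label score",
# the last line is the final state. Token ids here are made up.
s = """0 1 1 0.1
1 2 2 0.2
2 3 -1 0.0
3
"""
graph = k2.Fsa.from_str(s)
graph_with_loops = k2.add_epsilon_self_loops(graph)
# Every original arc is kept, and the non-final states gain a label-0 self-loop.
print(graph_with_loops.num_arcs)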
@@ -236,7 +236,9 @@ def get_texts(
         return aux_labels.tolist()
 
 
-def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
+def get_alignments(
+    best_paths: k2.Fsa, kind: str, remove_zero_blank: bool = False
+) -> List[List[int]]:
     """Extract labels or aux_labels from the best-path FSAs.
 
     Args:
@@ -272,6 +274,8 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
         token_shape, getattr(best_paths, kind).contiguous()
     )
     tokens = tokens.remove_values_eq(-1)
+    if remove_zero_blank:
+        tokens = tokens.remove_values_eq(0)
     return tokens.tolist()
 
 
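remove_values_eq operates on the ragged token tensor extracted from the best paths: dropping -1 removes the final-arc labels, and with remove_zero_blank=True the 0-valued (blank/epsilon) labels are dropped as well. A small standalone example of the same k2.RaggedTensor calls, with made-up token ids:

import k2

tokens = k2.RaggedTensor([[5, 0, 7, -1], [0, 9, -1]])
tokens = tokens.remove_values_eq(-1)  # drop final-arc labels
tokens = tokens.remove_values_eq(0)   # drop blank/epsilon labels (remove_zero_blank)
print(tokens.tolist())  # [[5, 7], [9]]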