mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
add self loop to L
This commit is contained in:
parent
0a99ceb6ba
commit
473efcd531
@ -538,6 +538,7 @@ def greedy_search_batch(
|
|||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
ngram_rescoring: bool = False,
|
ngram_rescoring: bool = False,
|
||||||
|
gamma_blank: float = 1.0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||||
Args:
|
Args:
|
||||||
@ -624,27 +625,12 @@ def greedy_search_batch(
|
|||||||
logits_argmax = logits.argmax(dim=1)
|
logits_argmax = logits.argmax(dim=1)
|
||||||
logits_softmax = logits.softmax(dim=1)
|
logits_softmax = logits.softmax(dim=1)
|
||||||
|
|
||||||
# detailed in the function verify_non_blank_logits below.
|
|
||||||
selection_verification = True
|
|
||||||
|
|
||||||
# 0 for blank frame and 1 for non-blank frame.
|
# 0 for blank frame and 1 for non-blank frame.
|
||||||
non_blank_flag[start:end] = torch.where(
|
non_blank_flag[start:end] = torch.where(
|
||||||
logits_argmax == blank_id, 0, 1
|
logits_softmax[:, 0] >= gamma_blank, 0, 1
|
||||||
)
|
)
|
||||||
|
|
||||||
if False:
|
|
||||||
# In paper: https://arxiv.org/pdf/2101.06856.pdf
|
|
||||||
# A gama_blank threshold value is used to determine blank frames.
|
|
||||||
# Currently, results are worse than baseline greedy_search
|
|
||||||
# and also very sensitive to gama_blank.
|
|
||||||
# (TODO): debug this later.
|
|
||||||
gama_blank = 0.50
|
|
||||||
non_blank_flag[start:end] = torch.where(
|
|
||||||
logits_softmax[:, 0] >= gama_blank, 0, 1
|
|
||||||
)
|
|
||||||
|
|
||||||
# function verify_non_blank_logits only works with logits_argmax == blank_id.
|
|
||||||
selection_verification = False
|
|
||||||
|
|
||||||
y = logits.argmax(dim=1).tolist()
|
y = logits.argmax(dim=1).tolist()
|
||||||
emitted = False
|
emitted = False
|
||||||
@ -710,21 +696,10 @@ def greedy_search_batch(
|
|||||||
all_logits_unpacked[i, :], 0, cur_non_blank_index
|
all_logits_unpacked[i, :], 0, cur_non_blank_index
|
||||||
)
|
)
|
||||||
|
|
||||||
def verify_non_blank_logits():
|
|
||||||
# A way to verify non_blank_logits are selected correctly from all_logits.
|
|
||||||
hyps_before_rescore = non_blank_logits.argmax(dim=2)
|
|
||||||
for i in range(N):
|
|
||||||
usi = unsorted_indices[i]
|
|
||||||
hyp_to_verify = hyps_before_rescore[usi][
|
|
||||||
: int(non_blank_logits_lens[usi])
|
|
||||||
].tolist()
|
|
||||||
assert ans[i] == hyp_to_verify
|
|
||||||
logging.info("Verified non-blank logits.")
|
|
||||||
|
|
||||||
# TODO: skip verification after we finally get a workable rescoring method.
|
|
||||||
if selection_verification:
|
|
||||||
verify_non_blank_logits()
|
|
||||||
|
|
||||||
|
number_selected_frames = non_blank_flag.sum()
|
||||||
|
logging.info(f"{number_selected_frames} are selected out of {total_t} frames")
|
||||||
# Split log_softmax into two separate steps,
|
# Split log_softmax into two separate steps,
|
||||||
# so we could do blank deweight in probability domain if needed.
|
# so we could do blank deweight in probability domain if needed.
|
||||||
logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
|
logits_to_rescore_softmax = non_blank_logits.softmax(dim=2)
|
||||||
@ -736,7 +711,7 @@ def greedy_search_batch(
|
|||||||
# So just put this blank deweight before ngram rescoring.
|
# So just put this blank deweight before ngram rescoring.
|
||||||
# (TODO): debug this blank deweight issue.
|
# (TODO): debug this blank deweight issue.
|
||||||
|
|
||||||
blank_deweight = 100
|
blank_deweight = 0.0
|
||||||
logits_to_rescore[:, :, 0] -= blank_deweight
|
logits_to_rescore[:, :, 0] -= blank_deweight
|
||||||
|
|
||||||
supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
|
supervision_segments = torch.zeros([N, 3], dtype=torch.int32)
|
||||||
@ -754,7 +729,7 @@ def greedy_search_batch(
|
|||||||
subsampling_factor=1,
|
subsampling_factor=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
lm_weight = 0.3 # (TODO): tuning this.
|
lm_weight = 0.5 # (TODO): tuning this.
|
||||||
lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
|
lattice.scores = lattice.scores - lattice.lm_scores * (1 - lm_weight)
|
||||||
|
|
||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
@ -762,7 +737,7 @@ def greedy_search_batch(
|
|||||||
use_double_scores=True,
|
use_double_scores=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
token_ids = get_alignments(best_path, "labels")
|
token_ids = get_alignments(best_path, "labels", remove_zero_blank=True)
|
||||||
|
|
||||||
ans = []
|
ans = []
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
|
@ -241,6 +241,12 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--gamma-blank",
|
||||||
|
type=int,
|
||||||
|
default=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -317,6 +323,7 @@ def decode_one_batch(
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
ngram_rescoring=params.ngram_rescoring,
|
ngram_rescoring=params.ngram_rescoring,
|
||||||
|
gamma_blank=params.gamma_blank,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -525,6 +532,7 @@ def main():
|
|||||||
if params.ngram_rescoring:
|
if params.ngram_rescoring:
|
||||||
params.suffix += "-ngram-rescoring"
|
params.suffix += "-ngram-rescoring"
|
||||||
params.suffix += f"-{params.decoding_graph}"
|
params.suffix += f"-{params.decoding_graph}"
|
||||||
|
params.suffix += f"-gamma_blank-{params.gamma_blank}"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
@ -636,8 +644,7 @@ def main():
|
|||||||
if params.ngram_rescoring and params.decoding_method == "greedy_search":
|
if params.ngram_rescoring and params.decoding_method == "greedy_search":
|
||||||
assert params.decoding_graph in [
|
assert params.decoding_graph in [
|
||||||
"trivial_graph",
|
"trivial_graph",
|
||||||
"HLG",
|
"L",
|
||||||
"Trivial_LG",
|
|
||||||
], f"Unsupported decoding graph {params.decoding_graph}"
|
], f"Unsupported decoding graph {params.decoding_graph}"
|
||||||
if params.decoding_graph == "trivial_graph":
|
if params.decoding_graph == "trivial_graph":
|
||||||
decoding_graph = k2.trivial_graph(
|
decoding_graph = k2.trivial_graph(
|
||||||
@ -650,6 +657,7 @@ def main():
|
|||||||
map_location=device,
|
map_location=device,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
|
||||||
|
|
||||||
decoding_graph.lm_scores = decoding_graph.scores.clone()
|
decoding_graph.lm_scores = decoding_graph.scores.clone()
|
||||||
|
|
||||||
|
@ -236,7 +236,9 @@ def get_texts(
|
|||||||
return aux_labels.tolist()
|
return aux_labels.tolist()
|
||||||
|
|
||||||
|
|
||||||
def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
|
def get_alignments(
|
||||||
|
best_paths: k2.Fsa, kind: str, remove_zero_blank: bool = False
|
||||||
|
) -> List[List[int]]:
|
||||||
"""Extract labels or aux_labels from the best-path FSAs.
|
"""Extract labels or aux_labels from the best-path FSAs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -272,6 +274,8 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]:
|
|||||||
token_shape, getattr(best_paths, kind).contiguous()
|
token_shape, getattr(best_paths, kind).contiguous()
|
||||||
)
|
)
|
||||||
tokens = tokens.remove_values_eq(-1)
|
tokens = tokens.remove_values_eq(-1)
|
||||||
|
if remove_zero_blank:
|
||||||
|
tokens = tokens.remove_values_eq(0)
|
||||||
return tokens.tolist()
|
return tokens.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user