add params.hlg_scale (#880)
Parent: caf23546ed
Commit: 5a05b95730
@@ -58,7 +58,6 @@ For example:
     --left-context 64 \
     --manifest-dir data/fbank_ali
 Note: It supports calculating symbol delay with the following decoding methods:
-  - ctc-greedy-search
   - ctc-decoding
   - 1best
 """
@@ -96,10 +95,8 @@ from icefall.decode import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    DecodingResults,
     get_texts,
     get_texts_with_timestamp,
-    make_pad_mask,
     parse_hyp_and_timestamp,
     setup_logger,
     store_transcripts_and_timestamps,
@@ -177,20 +174,18 @@ def get_parser():
           - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
             model, i.e., lang_dir/bpe.model, to convert word pieces to words.
             It needs neither a lexicon nor an n-gram LM.
-          - (1) ctc-greedy-search. It only uses the CTC output and a sentence piece
-            model for decoding. It produces the same results as ctc-decoding.
-          - (2) 1best. Extract the best path from the decoding lattice as the
+          - (1) 1best. Extract the best path from the decoding lattice as the
             decoding result.
-          - (3) nbest. Extract n paths from the decoding lattice; the path
+          - (2) nbest. Extract n paths from the decoding lattice; the path
             with the highest score is the decoding result.
-          - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+          - (3) nbest-rescoring. Extract n paths from the decoding lattice,
             rescore them with an n-gram LM (e.g., a 4-gram LM); the path with
             the highest score is the decoding result.
-          - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+          - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
             n-gram LM (e.g., a 4-gram LM); the best path of the rescored lattice
             is the decoding result.
             you have trained an RNN LM using ./rnn_lm/train.py
-          - (6) nbest-oracle. Its WER is the lower bound of what any n-best
+          - (5) nbest-oracle. Its WER is the lower bound of what any n-best
             rescoring method can achieve. Useful for debugging n-best
             rescoring methods.
         """,
@@ -250,6 +245,14 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )

+    parser.add_argument(
+        "--hlg-scale",
+        type=float,
+        default=0.8,
+        help="""The scale to be applied to `HLG.scores`.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser
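For orientation, a plausible invocation of the new flag follows; the script path and the --decoding-method flag name are assumptions inferred from the surrounding code, not something this diff states:

    ./conformer_ctc2/decode.py \
        --decoding-method 1best \
        --hlg-scale 0.6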
@@ -270,47 +273,6 @@ def get_decoding_params() -> AttributeDict:
     return params


-def ctc_greedy_search(
-    ctc_probs: torch.Tensor,
-    nnet_output_lens: torch.Tensor,
-) -> List[List[int]]:
-    """Apply CTC greedy search
-    Args:
-      ctc_probs (torch.Tensor): (batch, max_len, feat_dim)
-      nnet_output_lens (torch.Tensor): (batch, )
-    Returns:
-      List[List[int]]: best path result
-    """
-    topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
-    topk_index = topk_index.squeeze(2)  # (B, maxlen)
-    mask = make_pad_mask(nnet_output_lens)
-    topk_index = topk_index.masked_fill_(mask, 0)  # (B, maxlen)
-    hyps = [hyp.tolist() for hyp in topk_index]
-    scores = topk_prob.max(1)
-    ret_hyps = []
-    timestamps = []
-    for i in range(len(hyps)):
-        hyp, time = remove_duplicates_and_blank(hyps[i])
-        ret_hyps.append(hyp)
-        timestamps.append(time)
-    return ret_hyps, timestamps, scores
-
-
-def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]:
-    # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
-    new_hyp: List[int] = []
-    time: List[int] = []
-    cur = 0
-    while cur < len(hyp):
-        if hyp[cur] != 0:
-            new_hyp.append(hyp[cur])
-            time.append(cur)
-        prev = cur
-        while cur < len(hyp) and hyp[cur] == hyp[prev]:
-            cur += 1
-    return new_hyp, time
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
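A toy trace of the removed remove_duplicates_and_blank makes the collapsing rule concrete: blanks (token id 0) are dropped, consecutive repeats are merged, and each kept token records the frame index of its first occurrence:

    hyp = [0, 3, 3, 0, 0, 5, 5, 5, 0]
    new_hyp, time = remove_duplicates_and_blank(hyp)
    # new_hyp == [3, 5]; time == [1, 5]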
@@ -402,26 +364,11 @@ def decode_one_batch(
     nnet_output = model.get_ctc_output(encoder_out)
     # nnet_output is (N, T, C)

-    if params.decoding_method == "ctc-greedy-search":
-        hyps, timestamps, _ = ctc_greedy_search(
-            nnet_output,
-            encoder_out_lens,
-        )
-        res = DecodingResults(hyps=hyps, timestamps=timestamps)
-        hyps, timestamps = parse_hyp_and_timestamp(
-            res=res,
-            sp=bpe_model,
-            subsampling_factor=params.subsampling_factor,
-            frame_shift_ms=params.frame_shift_ms,
-        )
-        key = "ctc-greedy-search"
-        return {key: (hyps, timestamps)}
-
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
             supervisions["start_frame"] // params.subsampling_factor,
-            supervisions["num_frames"] // params.subsampling_factor,
+            encoder_out_lens.cpu(),
         ),
         1,
     ).to(torch.int32)
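For reference, supervision_segments is the three-column int32 tensor that k2's lattice builder expects, one row per utterance: (sequence index within the batch, start frame, number of frames), all in the subsampled frame domain. A sketch with made-up values:

    import torch

    # Two utterances, both starting at subsampled frame 0 (values are made up):
    supervision_segments = torch.tensor(
        [
            [0, 0, 153],
            [1, 0, 148],
        ],
        dtype=torch.int32,
    )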
@@ -434,75 +381,6 @@ def decode_one_batch(
         assert bpe_model is not None
         decoding_graph = H

-    if params.decoding_method in ["1best", "nbest", "nbest-oracle"]:
-        hlg_scale_list = [0.2, 0.4, 0.6, 0.8, 1.0]
-
-        ori_scores = decoding_graph.scores.clone()
-
-        ans = {}
-        for hlg_scale in hlg_scale_list:
-            decoding_graph.scores = ori_scores * hlg_scale
-            lattice = get_lattice(
-                nnet_output=nnet_output,
-                decoding_graph=decoding_graph,
-                supervision_segments=supervision_segments,
-                search_beam=params.search_beam,
-                output_beam=params.output_beam,
-                min_active_states=params.min_active_states,
-                max_active_states=params.max_active_states,
-                subsampling_factor=params.subsampling_factor,
-            )
-            key_suffix = f"-HLG-scale-{hlg_scale}"
-
-            if params.decoding_method == "nbest-oracle":
-                # Note: You can also pass rescored lattices to it.
-                # We choose the HLG decoded lattice for speed reasons
-                # as HLG decoding is faster and the oracle WER
-                # is only slightly worse than that of rescored lattices.
-                best_path = nbest_oracle(
-                    lattice=lattice,
-                    num_paths=params.num_paths,
-                    ref_texts=supervisions["text"],
-                    word_table=word_table,
-                    nbest_scale=params.nbest_scale,
-                    oov="<UNK>",
-                )
-                hyps = get_texts(best_path)
-                hyps = [[word_table[i] for i in ids] for ids in hyps]
-                key = f"oracle-{params.num_paths}-nbest-scale-{params.nbest_scale}"  # noqa
-                timestamps = [[] for _ in range(len(hyps))]
-                ans[key + key_suffix] = (hyps, timestamps)
-
-            elif params.decoding_method in ["1best", "nbest"]:
-                if params.decoding_method == "1best":
-                    best_path = one_best_decoding(
-                        lattice=lattice,
-                        use_double_scores=params.use_double_scores,
-                    )
-                    key = "no-rescore"
-                    res = get_texts_with_timestamp(best_path)
-                    hyps, timestamps = parse_hyp_and_timestamp(
-                        res=res,
-                        subsampling_factor=params.subsampling_factor,
-                        frame_shift_ms=params.frame_shift_ms,
-                        word_table=word_table,
-                    )
-                else:
-                    best_path = nbest_decoding(
-                        lattice=lattice,
-                        num_paths=params.num_paths,
-                        use_double_scores=params.use_double_scores,
-                        nbest_scale=params.nbest_scale,
-                    )
-                    key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
-                    hyps = get_texts(best_path)
-                    hyps = [[word_table[i] for i in ids] for ids in hyps]
-                    timestamps = [[] for _ in range(len(hyps))]
-
-                ans[key + key_suffix] = (hyps, timestamps)
-
-        return ans
-
     lattice = get_lattice(
         nnet_output=nnet_output,
         decoding_graph=decoding_graph,
@@ -532,6 +410,51 @@ def decode_one_batch(
         key = "ctc-decoding"
         return {key: (hyps, timestamps)}

+    if params.decoding_method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        timestamps = [[] for _ in range(len(hyps))]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}_hlg_scale_{params.hlg_scale}"  # noqa
+        return {key: (hyps, timestamps)}
+
+    if params.decoding_method in ["1best", "nbest"]:
+        if params.decoding_method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = f"no_rescore_hlg_scale_{params.hlg_scale}"
+            res = get_texts_with_timestamp(best_path)
+            hyps, timestamps = parse_hyp_and_timestamp(
+                res=res,
+                subsampling_factor=params.subsampling_factor,
+                frame_shift_ms=params.frame_shift_ms,
+                word_table=word_table,
+            )
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}-hlg-scale-{params.hlg_scale}"  # noqa
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            timestamps = [[] for _ in range(len(hyps))]
+        return {key: (hyps, timestamps)}
+
     assert params.decoding_method in [
         "nbest-rescoring",
         "whole-lattice-rescoring",
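Conceptually, the one_best_decoding call above reduces to a shortest-path search over the lattice; a minimal sketch of the same idea using k2's public API (the lattice variable is assumed to exist, as in decode_one_batch):

    import k2

    best_path = k2.shortest_path(lattice, use_double_scores=True)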
@@ -757,7 +680,6 @@ def main():
     params.update(vars(args))

     assert params.decoding_method in (
-        "ctc-greedy-search",
         "ctc-decoding",
         "1best",
         "nbest",
@@ -811,7 +733,7 @@ def main():
     params.sos_id = sos_id
     params.eos_id = eos_id

-    if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
+    if params.decoding_method == "ctc-decoding":
         HLG = None
         H = k2.ctc_topo(
             max_token=max_token_id,
@@ -828,6 +750,7 @@ def main():
         )
         assert HLG.requires_grad is False

+        HLG.scores *= params.hlg_scale
         if not hasattr(HLG, "lm_scores"):
             HLG.lm_scores = HLG.scores.clone()

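To see the new parameter in isolation, a minimal sketch of what the hunk above does (assuming k2 and torch are installed and an HLG graph has already been compiled; the HLG.pt path is an assumption):

    import k2
    import torch

    # Load a compiled HLG decoding graph (path is an assumption).
    HLG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/HLG.pt", map_location="cpu"))

    # A scale below 1.0 down-weights the graph (lexicon + n-gram LM) scores
    # relative to the acoustic scores, as `HLG.scores *= params.hlg_scale` above.
    HLG.scores *= 0.8  # 0.8 is the flag's default

Note the ordering in the hunk: the scale is applied before lm_scores is (possibly) cloned from scores, so the clone reflects the scaled values.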