mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
add params.hlg_scale (#880)
This commit is contained in:
parent
caf23546ed
commit
5a05b95730
@ -58,7 +58,6 @@ For example:
|
|||||||
--left-context 64 \
|
--left-context 64 \
|
||||||
--manifest-dir data/fbank_ali
|
--manifest-dir data/fbank_ali
|
||||||
Note: It supports calculating symbol delay with following decoding methods:
|
Note: It supports calculating symbol delay with following decoding methods:
|
||||||
- ctc-greedy-search
|
|
||||||
- ctc-decoding
|
- ctc-decoding
|
||||||
- 1best
|
- 1best
|
||||||
"""
|
"""
|
||||||
@ -96,10 +95,8 @@ from icefall.decode import (
|
|||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
DecodingResults,
|
|
||||||
get_texts,
|
get_texts,
|
||||||
get_texts_with_timestamp,
|
get_texts_with_timestamp,
|
||||||
make_pad_mask,
|
|
||||||
parse_hyp_and_timestamp,
|
parse_hyp_and_timestamp,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts_and_timestamps,
|
store_transcripts_and_timestamps,
|
||||||
@ -177,20 +174,18 @@ def get_parser():
|
|||||||
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||||
It needs neither a lexicon nor an n-gram LM.
|
It needs neither a lexicon nor an n-gram LM.
|
||||||
- (1) ctc-greedy-search. It only use CTC output and a sentence piece
|
- (1) 1best. Extract the best path from the decoding lattice as the
|
||||||
model for decoding. It produces the same results with ctc-decoding.
|
|
||||||
- (2) 1best. Extract the best path from the decoding lattice as the
|
|
||||||
decoding result.
|
decoding result.
|
||||||
- (3) nbest. Extract n paths from the decoding lattice; the path
|
- (2) nbest. Extract n paths from the decoding lattice; the path
|
||||||
with the highest score is the decoding result.
|
with the highest score is the decoding result.
|
||||||
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
|
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||||
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||||
the highest score is the decoding result.
|
the highest score is the decoding result.
|
||||||
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
|
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||||
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||||
is the decoding result.
|
is the decoding result.
|
||||||
you have trained an RNN LM using ./rnn_lm/train.py
|
you have trained an RNN LM using ./rnn_lm/train.py
|
||||||
- (6) nbest-oracle. Its WER is the lower bound of any n-best
|
- (5) nbest-oracle. Its WER is the lower bound of any n-best
|
||||||
rescoring method can achieve. Useful for debugging n-best
|
rescoring method can achieve. Useful for debugging n-best
|
||||||
rescoring method.
|
rescoring method.
|
||||||
""",
|
""",
|
||||||
@ -250,6 +245,14 @@ def get_parser():
|
|||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
help="left context can be seen during decoding (in frames after subsampling)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hlg-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="""The scale to be applied to `hlg.scores`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -270,47 +273,6 @@ def get_decoding_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def ctc_greedy_search(
|
|
||||||
ctc_probs: torch.Tensor,
|
|
||||||
nnet_output_lens: torch.Tensor,
|
|
||||||
) -> List[List[int]]:
|
|
||||||
"""Apply CTC greedy search
|
|
||||||
Args:
|
|
||||||
ctc_probs (torch.Tensor): (batch, max_len, feat_dim)
|
|
||||||
nnet_output_lens (torch.Tensor): (batch, )
|
|
||||||
Returns:
|
|
||||||
List[List[int]]: best path result
|
|
||||||
"""
|
|
||||||
topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1)
|
|
||||||
topk_index = topk_index.squeeze(2) # (B, maxlen)
|
|
||||||
mask = make_pad_mask(nnet_output_lens)
|
|
||||||
topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen)
|
|
||||||
hyps = [hyp.tolist() for hyp in topk_index]
|
|
||||||
scores = topk_prob.max(1)
|
|
||||||
ret_hyps = []
|
|
||||||
timestamps = []
|
|
||||||
for i in range(len(hyps)):
|
|
||||||
hyp, time = remove_duplicates_and_blank(hyps[i])
|
|
||||||
ret_hyps.append(hyp)
|
|
||||||
timestamps.append(time)
|
|
||||||
return ret_hyps, timestamps, scores
|
|
||||||
|
|
||||||
|
|
||||||
def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]:
|
|
||||||
# modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
|
|
||||||
new_hyp: List[int] = []
|
|
||||||
time: List[int] = []
|
|
||||||
cur = 0
|
|
||||||
while cur < len(hyp):
|
|
||||||
if hyp[cur] != 0:
|
|
||||||
new_hyp.append(hyp[cur])
|
|
||||||
time.append(cur)
|
|
||||||
prev = cur
|
|
||||||
while cur < len(hyp) and hyp[cur] == hyp[prev]:
|
|
||||||
cur += 1
|
|
||||||
return new_hyp, time
|
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -402,26 +364,11 @@ def decode_one_batch(
|
|||||||
nnet_output = model.get_ctc_output(encoder_out)
|
nnet_output = model.get_ctc_output(encoder_out)
|
||||||
# nnet_output is (N, T, C)
|
# nnet_output is (N, T, C)
|
||||||
|
|
||||||
if params.decoding_method == "ctc-greedy-search":
|
|
||||||
hyps, timestamps, _ = ctc_greedy_search(
|
|
||||||
nnet_output,
|
|
||||||
encoder_out_lens,
|
|
||||||
)
|
|
||||||
res = DecodingResults(hyps=hyps, timestamps=timestamps)
|
|
||||||
hyps, timestamps = parse_hyp_and_timestamp(
|
|
||||||
res=res,
|
|
||||||
sp=bpe_model,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
frame_shift_ms=params.frame_shift_ms,
|
|
||||||
)
|
|
||||||
key = "ctc-greedy-search"
|
|
||||||
return {key: (hyps, timestamps)}
|
|
||||||
|
|
||||||
supervision_segments = torch.stack(
|
supervision_segments = torch.stack(
|
||||||
(
|
(
|
||||||
supervisions["sequence_idx"],
|
supervisions["sequence_idx"],
|
||||||
supervisions["start_frame"] // params.subsampling_factor,
|
supervisions["start_frame"] // params.subsampling_factor,
|
||||||
supervisions["num_frames"] // params.subsampling_factor,
|
encoder_out_lens.cpu(),
|
||||||
),
|
),
|
||||||
1,
|
1,
|
||||||
).to(torch.int32)
|
).to(torch.int32)
|
||||||
@ -434,75 +381,6 @@ def decode_one_batch(
|
|||||||
assert bpe_model is not None
|
assert bpe_model is not None
|
||||||
decoding_graph = H
|
decoding_graph = H
|
||||||
|
|
||||||
if params.decoding_method in ["1best", "nbest", "nbest-oracle"]:
|
|
||||||
hlg_scale_list = [0.2, 0.4, 0.6, 0.8, 1.0]
|
|
||||||
|
|
||||||
ori_scores = decoding_graph.scores.clone()
|
|
||||||
|
|
||||||
ans = {}
|
|
||||||
for hlg_scale in hlg_scale_list:
|
|
||||||
decoding_graph.scores = ori_scores * hlg_scale
|
|
||||||
lattice = get_lattice(
|
|
||||||
nnet_output=nnet_output,
|
|
||||||
decoding_graph=decoding_graph,
|
|
||||||
supervision_segments=supervision_segments,
|
|
||||||
search_beam=params.search_beam,
|
|
||||||
output_beam=params.output_beam,
|
|
||||||
min_active_states=params.min_active_states,
|
|
||||||
max_active_states=params.max_active_states,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
)
|
|
||||||
key_suffix = f"-HLG-scale-{hlg_scale}"
|
|
||||||
|
|
||||||
if params.decoding_method == "nbest-oracle":
|
|
||||||
# Note: You can also pass rescored lattices to it.
|
|
||||||
# We choose the HLG decoded lattice for speed reasons
|
|
||||||
# as HLG decoding is faster and the oracle WER
|
|
||||||
# is only slightly worse than that of rescored lattices.
|
|
||||||
best_path = nbest_oracle(
|
|
||||||
lattice=lattice,
|
|
||||||
num_paths=params.num_paths,
|
|
||||||
ref_texts=supervisions["text"],
|
|
||||||
word_table=word_table,
|
|
||||||
nbest_scale=params.nbest_scale,
|
|
||||||
oov="<UNK>",
|
|
||||||
)
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
|
||||||
key = f"oracle-{params.num_paths}-nbest-scale-{params.nbest_scale}" # noqa
|
|
||||||
timestamps = [[] for _ in range(len(hyps))]
|
|
||||||
ans[key + key_suffix] = (hyps, timestamps)
|
|
||||||
|
|
||||||
elif params.decoding_method in ["1best", "nbest"]:
|
|
||||||
if params.decoding_method == "1best":
|
|
||||||
best_path = one_best_decoding(
|
|
||||||
lattice=lattice,
|
|
||||||
use_double_scores=params.use_double_scores,
|
|
||||||
)
|
|
||||||
key = "no-rescore"
|
|
||||||
res = get_texts_with_timestamp(best_path)
|
|
||||||
hyps, timestamps = parse_hyp_and_timestamp(
|
|
||||||
res=res,
|
|
||||||
subsampling_factor=params.subsampling_factor,
|
|
||||||
frame_shift_ms=params.frame_shift_ms,
|
|
||||||
word_table=word_table,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
best_path = nbest_decoding(
|
|
||||||
lattice=lattice,
|
|
||||||
num_paths=params.num_paths,
|
|
||||||
use_double_scores=params.use_double_scores,
|
|
||||||
nbest_scale=params.nbest_scale,
|
|
||||||
)
|
|
||||||
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
|
||||||
timestamps = [[] for _ in range(len(hyps))]
|
|
||||||
|
|
||||||
ans[key + key_suffix] = (hyps, timestamps)
|
|
||||||
|
|
||||||
return ans
|
|
||||||
|
|
||||||
lattice = get_lattice(
|
lattice = get_lattice(
|
||||||
nnet_output=nnet_output,
|
nnet_output=nnet_output,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
@ -532,6 +410,51 @@ def decode_one_batch(
|
|||||||
key = "ctc-decoding"
|
key = "ctc-decoding"
|
||||||
return {key: (hyps, timestamps)}
|
return {key: (hyps, timestamps)}
|
||||||
|
|
||||||
|
if params.decoding_method == "nbest-oracle":
|
||||||
|
# Note: You can also pass rescored lattices to it.
|
||||||
|
# We choose the HLG decoded lattice for speed reasons
|
||||||
|
# as HLG decoding is faster and the oracle WER
|
||||||
|
# is only slightly worse than that of rescored lattices.
|
||||||
|
best_path = nbest_oracle(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=supervisions["text"],
|
||||||
|
word_table=word_table,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
oov="<UNK>",
|
||||||
|
)
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||||
|
timestamps = [[] for _ in range(len(hyps))]
|
||||||
|
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}_hlg_scale_{params.hlg_scale}" # noqa
|
||||||
|
return {key: (hyps, timestamps)}
|
||||||
|
|
||||||
|
if params.decoding_method in ["1best", "nbest"]:
|
||||||
|
if params.decoding_method == "1best":
|
||||||
|
best_path = one_best_decoding(
|
||||||
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
|
)
|
||||||
|
key = f"no_rescore_hlg_scale_{params.hlg_scale}"
|
||||||
|
res = get_texts_with_timestamp(best_path)
|
||||||
|
hyps, timestamps = parse_hyp_and_timestamp(
|
||||||
|
res=res,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
frame_shift_ms=params.frame_shift_ms,
|
||||||
|
word_table=word_table,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
best_path = nbest_decoding(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
use_double_scores=params.use_double_scores,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}-hlg-scale-{params.hlg_scale}" # noqa
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||||
|
timestamps = [[] for _ in range(len(hyps))]
|
||||||
|
return {key: (hyps, timestamps)}
|
||||||
|
|
||||||
assert params.decoding_method in [
|
assert params.decoding_method in [
|
||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
"whole-lattice-rescoring",
|
"whole-lattice-rescoring",
|
||||||
@ -757,7 +680,6 @@ def main():
|
|||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
assert params.decoding_method in (
|
assert params.decoding_method in (
|
||||||
"ctc-greedy-search",
|
|
||||||
"ctc-decoding",
|
"ctc-decoding",
|
||||||
"1best",
|
"1best",
|
||||||
"nbest",
|
"nbest",
|
||||||
@ -811,7 +733,7 @@ def main():
|
|||||||
params.sos_id = sos_id
|
params.sos_id = sos_id
|
||||||
params.eos_id = eos_id
|
params.eos_id = eos_id
|
||||||
|
|
||||||
if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
|
if params.decoding_method == "ctc-decoding":
|
||||||
HLG = None
|
HLG = None
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
@ -828,6 +750,7 @@ def main():
|
|||||||
)
|
)
|
||||||
assert HLG.requires_grad is False
|
assert HLG.requires_grad is False
|
||||||
|
|
||||||
|
HLG.scores *= params.hlg_scale
|
||||||
if not hasattr(HLG, "lm_scores"):
|
if not hasattr(HLG, "lm_scores"):
|
||||||
HLG.lm_scores = HLG.scores.clone()
|
HLG.lm_scores = HLG.scores.clone()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user