add params.hlg_scale (#880)

Zengwei Yao 2023-02-06 23:21:46 +08:00 committed by GitHub
parent caf23546ed
commit 5a05b95730

@@ -58,7 +58,6 @@ For example:
     --left-context 64 \
     --manifest-dir data/fbank_ali
 Note: It supports calculating symbol delay with following decoding methods:
-  - ctc-greedy-search
   - ctc-decoding
   - 1best
 """
@ -96,10 +95,8 @@ from icefall.decode import (
from icefall.lexicon import Lexicon from icefall.lexicon import Lexicon
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
DecodingResults,
get_texts, get_texts,
get_texts_with_timestamp, get_texts_with_timestamp,
make_pad_mask,
parse_hyp_and_timestamp, parse_hyp_and_timestamp,
setup_logger, setup_logger,
store_transcripts_and_timestamps, store_transcripts_and_timestamps,
@@ -177,20 +174,18 @@ def get_parser():
         - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
           model, i.e., lang_dir/bpe.model, to convert word pieces to words.
           It needs neither a lexicon nor an n-gram LM.
-        - (1) ctc-greedy-search. It only use CTC output and a sentence piece
-          model for decoding. It produces the same results with ctc-decoding.
-        - (2) 1best. Extract the best path from the decoding lattice as the
+        - (1) 1best. Extract the best path from the decoding lattice as the
          decoding result.
-        - (3) nbest. Extract n paths from the decoding lattice; the path
+        - (2) nbest. Extract n paths from the decoding lattice; the path
          with the highest score is the decoding result.
-        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
+        - (3) nbest-rescoring. Extract n paths from the decoding lattice,
          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
          the highest score is the decoding result.
-        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
+        - (4) whole-lattice-rescoring. Rescore the decoding lattice with an
          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
          is the decoding result.
          you have trained an RNN LM using ./rnn_lm/train.py
-        - (6) nbest-oracle. Its WER is the lower bound of any n-best
+        - (5) nbest-oracle. Its WER is the lower bound of any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring method.
     """,
@@ -250,6 +245,14 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )

+    parser.add_argument(
+        "--hlg-scale",
+        type=float,
+        default=0.8,
+        help="""The scale to be applied to `hlg.scores`.
+        """,
+    )
+
     add_model_arguments(parser)

     return parser
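The new flag follows argparse's usual dash-to-underscore mapping, which is why the option added here as --hlg-scale is read later in this file as params.hlg_scale. A minimal sketch of that flow, assuming nothing beyond argparse and the params.update(vars(args)) call that main() already performs:

    import argparse

    # Toy parser with just the option added in this commit.
    parser = argparse.ArgumentParser()
    parser.add_argument("--hlg-scale", type=float, default=0.8)

    args = parser.parse_args(["--hlg-scale", "0.6"])
    # argparse turns "--hlg-scale" into the attribute "hlg_scale", so after
    # params.update(vars(args)) it is available as params.hlg_scale.
    assert args.hlg_scale == 0.6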
@@ -270,47 +273,6 @@ def get_decoding_params() -> AttributeDict:
     return params

-def ctc_greedy_search(
-    ctc_probs: torch.Tensor,
-    nnet_output_lens: torch.Tensor,
-) -> List[List[int]]:
-    """Apply CTC greedy search
-    Args:
-        ctc_probs (torch.Tensor): (batch, max_len, feat_dim)
-        nnet_output_lens (torch.Tensor): (batch, )
-    Returns:
-        List[List[int]]: best path result
-    """
-    topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
-    topk_index = topk_index.squeeze(2)  # (B, maxlen)
-    mask = make_pad_mask(nnet_output_lens)
-    topk_index = topk_index.masked_fill_(mask, 0)  # (B, maxlen)
-    hyps = [hyp.tolist() for hyp in topk_index]
-    scores = topk_prob.max(1)
-    ret_hyps = []
-    timestamps = []
-    for i in range(len(hyps)):
-        hyp, time = remove_duplicates_and_blank(hyps[i])
-        ret_hyps.append(hyp)
-        timestamps.append(time)
-    return ret_hyps, timestamps, scores
-
-
-def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]:
-    # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
-    new_hyp: List[int] = []
-    time: List[int] = []
-    cur = 0
-    while cur < len(hyp):
-        if hyp[cur] != 0:
-            new_hyp.append(hyp[cur])
-            time.append(cur)
-        prev = cur
-        while cur < len(hyp) and hyp[cur] == hyp[prev]:
-            cur += 1
-    return new_hyp, time
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
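The two helpers deleted above implemented standard CTC greedy search: take the per-frame argmax token, drop blanks (id 0), and collapse consecutive repeats, keeping the first frame of each surviving token as a crude timestamp. A self-contained sketch of that collapse rule, independent of the deleted icefall helpers:

    from typing import List, Tuple

    def greedy_collapse(frame_ids: List[int], blank: int = 0) -> Tuple[List[int], List[int]]:
        """Drop blanks and collapse repeats in a frame-level argmax sequence."""
        tokens: List[int] = []
        times: List[int] = []
        prev = blank
        for t, tok in enumerate(frame_ids):
            if tok != blank and tok != prev:
                tokens.append(tok)  # token starts firing at frame t
                times.append(t)
            prev = tok
        return tokens, times

    # Frames [0, 5, 5, 0, 3, 3, 3, 0] decode to tokens [5, 3] at frames [1, 4].
    assert greedy_collapse([0, 5, 5, 0, 3, 3, 3, 0]) == ([5, 3], [1, 4])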
@@ -402,26 +364,11 @@ def decode_one_batch(
     nnet_output = model.get_ctc_output(encoder_out)
     # nnet_output is (N, T, C)

-    if params.decoding_method == "ctc-greedy-search":
-        hyps, timestamps, _ = ctc_greedy_search(
-            nnet_output,
-            encoder_out_lens,
-        )
-        res = DecodingResults(hyps=hyps, timestamps=timestamps)
-        hyps, timestamps = parse_hyp_and_timestamp(
-            res=res,
-            sp=bpe_model,
-            subsampling_factor=params.subsampling_factor,
-            frame_shift_ms=params.frame_shift_ms,
-        )
-        key = "ctc-greedy-search"
-        return {key: (hyps, timestamps)}
-
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
             supervisions["start_frame"] // params.subsampling_factor,
-            supervisions["num_frames"] // params.subsampling_factor,
+            encoder_out_lens.cpu(),
         ),
         1,
     ).to(torch.int32)
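Switching the third column from supervisions["num_frames"] // params.subsampling_factor to encoder_out_lens.cpu() makes the segment lengths match the lengths the encoder actually produced, which can differ from the naive integer division once convolutional subsampling and padding are involved. A toy illustration of the tensor shape k2 expects (three cuts in a batch; the lengths are made up):

    import torch

    sequence_idx = torch.tensor([0, 1, 2])            # which cut in the batch
    start_frame = torch.tensor([0, 0, 0])             # start within nnet_output
    encoder_out_lens = torch.tensor([153, 140, 98])   # actual encoder output lengths

    # Rows of (sequence_idx, start_frame, num_frames), int32, as consumed
    # by get_lattice / k2.DenseFsaVec.
    supervision_segments = torch.stack(
        (sequence_idx, start_frame, encoder_out_lens.cpu()), 1
    ).to(torch.int32)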
@@ -434,75 +381,6 @@ def decode_one_batch(
         assert bpe_model is not None
         decoding_graph = H

-    if params.decoding_method in ["1best", "nbest", "nbest-oracle"]:
-        hlg_scale_list = [0.2, 0.4, 0.6, 0.8, 1.0]
-        ori_scores = decoding_graph.scores.clone()
-        ans = {}
-        for hlg_scale in hlg_scale_list:
-            decoding_graph.scores = ori_scores * hlg_scale
-            lattice = get_lattice(
-                nnet_output=nnet_output,
-                decoding_graph=decoding_graph,
-                supervision_segments=supervision_segments,
-                search_beam=params.search_beam,
-                output_beam=params.output_beam,
-                min_active_states=params.min_active_states,
-                max_active_states=params.max_active_states,
-                subsampling_factor=params.subsampling_factor,
-            )
-            key_suffix = f"-HLG-scale-{hlg_scale}"
-
-            if params.decoding_method == "nbest-oracle":
-                # Note: You can also pass rescored lattices to it.
-                # We choose the HLG decoded lattice for speed reasons
-                # as HLG decoding is faster and the oracle WER
-                # is only slightly worse than that of rescored lattices.
-                best_path = nbest_oracle(
-                    lattice=lattice,
-                    num_paths=params.num_paths,
-                    ref_texts=supervisions["text"],
-                    word_table=word_table,
-                    nbest_scale=params.nbest_scale,
-                    oov="<UNK>",
-                )
-                hyps = get_texts(best_path)
-                hyps = [[word_table[i] for i in ids] for ids in hyps]
-                key = f"oracle-{params.num_paths}-nbest-scale-{params.nbest_scale}"  # noqa
-                timestamps = [[] for _ in range(len(hyps))]
-                ans[key + key_suffix] = (hyps, timestamps)
-            elif params.decoding_method in ["1best", "nbest"]:
-                if params.decoding_method == "1best":
-                    best_path = one_best_decoding(
-                        lattice=lattice,
-                        use_double_scores=params.use_double_scores,
-                    )
-                    key = "no-rescore"
-                    res = get_texts_with_timestamp(best_path)
-                    hyps, timestamps = parse_hyp_and_timestamp(
-                        res=res,
-                        subsampling_factor=params.subsampling_factor,
-                        frame_shift_ms=params.frame_shift_ms,
-                        word_table=word_table,
-                    )
-                else:
-                    best_path = nbest_decoding(
-                        lattice=lattice,
-                        num_paths=params.num_paths,
-                        use_double_scores=params.use_double_scores,
-                        nbest_scale=params.nbest_scale,
-                    )
-                    key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
-                    hyps = get_texts(best_path)
-                    hyps = [[word_table[i] for i in ids] for ids in hyps]
-                    timestamps = [[] for _ in range(len(hyps))]
-                ans[key + key_suffix] = (hyps, timestamps)
-        return ans
-
     lattice = get_lattice(
         nnet_output=nnet_output,
         decoding_graph=decoding_graph,
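The block removed above swept hlg_scale over [0.2, 0.4, 0.6, 0.8, 1.0] inside every call to decode_one_batch, rescaling decoding_graph.scores on each iteration. After this commit the scale is applied once, at graph-loading time, so comparing scales means one decoding run per value. A hedged sketch of reproducing the old sweep externally; the script path is illustrative and the recipe's other required flags (checkpoint, manifest dirs, ...) are omitted:

    import subprocess

    # One decode.py invocation per scale, mirroring the removed in-loop sweep.
    for hlg_scale in [0.2, 0.4, 0.6, 0.8, 1.0]:
        subprocess.run(
            ["python", "decode.py", "--hlg-scale", str(hlg_scale)],
            check=True,
        )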
@@ -532,6 +410,51 @@ def decode_one_batch(
         key = "ctc-decoding"
         return {key: (hyps, timestamps)}

+    if params.decoding_method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        timestamps = [[] for _ in range(len(hyps))]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}_hlg_scale_{params.hlg_scale}"  # noqa
+        return {key: (hyps, timestamps)}
+
+    if params.decoding_method in ["1best", "nbest"]:
+        if params.decoding_method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = f"no_rescore_hlg_scale_{params.hlg_scale}"
+            res = get_texts_with_timestamp(best_path)
+            hyps, timestamps = parse_hyp_and_timestamp(
+                res=res,
+                subsampling_factor=params.subsampling_factor,
+                frame_shift_ms=params.frame_shift_ms,
+                word_table=word_table,
+            )
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}-hlg-scale-{params.hlg_scale}"  # noqa
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            timestamps = [[] for _ in range(len(hyps))]
+        return {key: (hyps, timestamps)}
+
     assert params.decoding_method in [
         "nbest-rescoring",
         "whole-lattice-rescoring",
@@ -757,7 +680,6 @@ def main():
     params.update(vars(args))

     assert params.decoding_method in (
-        "ctc-greedy-search",
         "ctc-decoding",
         "1best",
         "nbest",
@@ -811,7 +733,7 @@ def main():
     params.sos_id = sos_id
     params.eos_id = eos_id

-    if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]:
+    if params.decoding_method == "ctc-decoding":
         HLG = None
         H = k2.ctc_topo(
             max_token=max_token_id,
@@ -828,6 +750,7 @@ def main():
         )
         assert HLG.requires_grad is False

+        HLG.scores *= params.hlg_scale
         if not hasattr(HLG, "lm_scores"):
             HLG.lm_scores = HLG.scores.clone()
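Scaling HLG.scores rescales every arc weight in the decoding graph, i.e. the contribution of the lexicon and n-gram LM relative to the acoustic log-probabilities; the default of 0.8 slightly down-weights the graph. Note the scale is applied before lm_scores is cloned, so the clone carries the scale as well. A toy k2 example of the same in-place operation, with a two-arc FSA standing in for the real HLG:

    import k2

    # Tiny FSA in k2's text format: "src dst label score" per arc,
    # final state on its own line.
    s = """
    0 1 1 0.5
    1 2 -1 1.5
    2
    """
    fsa = k2.Fsa.from_str(s)
    fsa.scores *= 0.8   # same operation the commit applies to HLG
    print(fsa.scores)   # tensor([0.4000, 1.2000])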