Mirror of https://github.com/k2-fsa/icefall.git
Fix the score of the nbest attention rescorer
This commit is contained in:
parent 1ac52e5bcb
commit 355e3244c8
@@ -125,8 +125,8 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            # "exp_dir": Path("exp/conformer_ctc"),
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("exp/conformer_ctc"),
+            # "exp_dir": Path("conformer_ctc/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
@@ -330,6 +330,7 @@ def decode_one_batch(
             rescore_est_model=rescore_est_model,
             sos_id=sos_id,
             eos_id=eos_id,
+            scale=params.lattice_score_scale,
         )
         if params.dump_best_matching_feature:
             if best_path_dict.size()[0] > 0:
@@ -612,13 +613,13 @@ def main():
     #
     test_sets = ["test-clean", "test-other"]
     for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-        if test_set == "test-other": continue
+        # if test_set == "test-other": continue
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
             HLG=HLG,
-            word_table=word_table,
+            word_table=lexicon.word_table,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
@@ -932,7 +932,7 @@ def rescore_nbest_with_attention_decoder(
     token_seq = k2.ragged.remove_values_leq(token_seq, -1)
-    # token seq shape [utt][path][token]
-    token_seq = k2.ragged.remove_axis(token_seq, 2)
-    # token seq shape [utt][token]
+    # token seq shape [path][token]
+    token_seq = k2.ragged.remove_axis(token_seq, 0)
 
     token_ids = k2.ragged.to_list(token_seq)
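The axis change above switches token_seq from [utt][path][token] to [path][token]. A minimal pure-Python sketch of what dropping the leading ragged axis means, using plain nested lists rather than the k2 ragged API (values are illustrative):

token_seq = [
    [[5, 9, 2], [5, 7]],  # paths of utterance 0
    [[3, 3, 8]],          # paths of utterance 1
]
# removing axis 0 concatenates the per-utterance groups,
# leaving axes [path][token]
paths = [path for utt in token_seq for path in utt]
assert paths == [[5, 9, 2], [5, 7], [3, 3, 8]]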
@@ -944,7 +944,6 @@ def rescore_nbest_with_attention_decoder(
         0, path_to_seq_map_long
     )
 
-    # TODO: pass the sos_token_id and eos_token_id via function arguments
     nll = model.decoder_nll(
         memory=expanded_memory,
         memory_key_padding_mask=expanded_memory_key_padding_mask,
@@ -987,6 +986,9 @@ def rescore_with_attention_decoder_v2(
     rescore_est_model: nn.Module,
     sos_id: int,
     eos_id: int,
+    scale: float = 1.0,
+    ngram_lm_scale: Optional[float] = None,
+    attention_scale: Optional[float] = None,
 ) -> Union[torch.Tensor, Dict[str, k2.Fsa]]:
     """This function extracts n paths from the given lattice and uses
     an attention decoder to rescore them. The path with the highest
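The three new keyword arguments make the v2 rescorer steerable from the decode script: scale is threaded through to path sampling (see the generate_nbest_list hunks below), while ngram_lm_scale and attention_scale, when given, pin a single scale pair instead of searching the full 14x14 grid. A hedged sketch of the resulting call; arguments not shown in this diff (lattice, num_paths, model, and the numeric values) are illustrative stand-ins:

best_path_dict = rescore_with_attention_decoder_v2(
    lattice=lattice,            # decoding lattice for the batch
    num_paths=100,              # hypothetical n-best size
    model=model,
    rescore_est_model=rescore_est_model,
    sos_id=sos_id,
    eos_id=eos_id,
    scale=0.5,                  # flatten lattice scores when sampling paths
    ngram_lm_scale=1.2,         # pin one pair instead of the grid search
    attention_scale=1.3,
)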
@@ -1022,7 +1024,7 @@ def rescore_with_attention_decoder_v2(
         ngram_lm_scale_attention_scale and the value is the
         best decoding path for each sequence in the lattice.
     """
-    nbest = generate_nbest_list(lattice, num_paths)
+    nbest = generate_nbest_list(lattice, num_paths, scale)
 
     if dump_best_matching_feature:
         if nbest.fsa.arcs.dim0() <= 2 * top_k or nbest.fsa.arcs.num_elements() == 0:
@@ -1101,14 +1103,11 @@ def rescore_with_attention_decoder_v2(
         # tot_scores() shape [utt][path]
         # path_score with elements numbers equals numbers of paths
         # !!! Note: This is right only when utt equals to 1
         path_scores = nbest_remain.total_scores().values()
         best_score = torch.max(rescored_nbest_topk.total_scores().values())
-        est_scores = 1 - 1/2 * (
-            1 + torch.erf(
-                (best_score - path_mean) / torch.sqrt(2 * path_var)
-            )
-        )
-        est_scores = k2.RaggedFloat(nbest_remain.shape, est_scores)
+        est_scores = (path_mean - best_score) / torch.sqrt(path_var)
+        # print (f"best score : {best_score}, est_scores : {est_scores}")
+        est_scores = k2.RaggedFloat(nbest_remain.shape, -est_scores)
 
         # calculate nbest_remain estimated score and select topk
         nbest_remain_topk = nbest_remain.top_k(k=top_k, scores=est_scores)
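This hunk is the score fix the commit title refers to. The old expression is the Gaussian tail probability 1 - Phi((best - mu) / sigma) that a remaining path beats the current best; in float arithmetic it saturates to exactly 0 or 1 once a path sits a few standard deviations from best_score, flattening the ordering that top_k relies on. The replacement keeps the raw z-score instead, with a sign flip consistent with the other sign fixes in this commit. A sketch with stand-in tensors (values are hypothetical):

import torch

path_mean = torch.tensor([-12.0, -15.5, -9.8])   # hypothetical per-path means
path_var = torch.tensor([4.0, 2.5, 3.2])         # hypothetical per-path variances
best_score = torch.tensor(-10.0)

# old: tail probability, 1 - Phi(z); flat at 0/1 for large |z|
old_est = 1 - 0.5 * (1 + torch.erf((best_score - path_mean) / torch.sqrt(2 * path_var)))

# new: signed z-score, negated as in the diff before being wrapped
# into k2.RaggedFloat
new_est = -(path_mean - best_score) / torch.sqrt(path_var)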
@@ -1134,9 +1133,10 @@ def rescore_with_attention_decoder_v2(
             merge_fsa, k2.ragged.create_ragged_shape2(row_ids=merge_row_ids)
         )
 
-        attention_scores = rescore_nbest.fsa.scores
+        attention_scores = -rescore_nbest.fsa.scores
         am_scores = torch.cat((am_scores, remain_am_scores))
         lm_scores = torch.cat((lm_scores, remain_lm_scores))
 
     else:
         am_scores = nbest.fsa.scores - nbest.fsa.lm_scores
         lm_scores = nbest.fsa.lm_scores
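The sign flip here (and again in the next hunk) is the other half of the fix: as the decoder_nll call in the earlier hunk suggests, the rescored Nbest carries the attention decoder's negative log-likelihood in fsa.scores, i.e. a cost, so it must be negated to become a score where larger means better before it is mixed with the AM and LM scores. A one-line illustration with hypothetical values:

import torch

nll = torch.tensor([12.3, 15.1, 9.7])  # hypothetical per-path NLLs
attention_scores = -nll                # best path now has the highest score
assert attention_scores.argmax() == nll.argmin()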
@@ -1148,13 +1148,20 @@ def rescore_with_attention_decoder_v2(
             sos_id=sos_id,
             eos_id=eos_id
         )
-        attention_scores = rescore_nbest.fsa.scores
+        attention_scores = -rescore_nbest.fsa.scores
 
-    ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
-    ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    if ngram_lm_scale is None:
+        ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        ngram_lm_scale_list = [ngram_lm_scale]
 
+    if attention_scale is None:
+        attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        attention_scale_list = [attention_scale]
 
-    attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
-    attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
 
     ans = dict()
     for n_scale in ngram_lm_scale_list:
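With the Optional arguments from the signature hunk, the double grid below collapses to a single combination when both scales are supplied. A self-contained sketch of the search that follows; the loop header and the key format come from the diff and docstring, the score values and the argmax body are illustrative:

import torch

# stand-ins for the per-path scores assembled above (hypothetical values)
am_scores = torch.tensor([-50.0, -52.1, -49.3])
lm_scores = torch.tensor([-20.2, -18.9, -21.5])
attention_scores = torch.tensor([-12.3, -15.1, -9.7])

ngram_lm_scale_list = [1.2]    # as if ngram_lm_scale=1.2 were passed
attention_scale_list = [1.3]   # as if attention_scale=1.3 were passed

ans = dict()
for n_scale in ngram_lm_scale_list:
    for a_scale in attention_scale_list:
        tot_scores = am_scores + n_scale * lm_scores + a_scale * attention_scores
        key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
        ans[key] = tot_scores.argmax()  # index of the best path (illustrative)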
@@ -1180,7 +1187,7 @@ def rescore_with_attention_decoder_v2(
 def generate_nbest_list(
     lats: k2.Fsa,
     num_paths: int,
-    aux_labels: bool = False
+    scale: float = 1.0
 ) -> Nbest:
     '''Generate an n-best list from a lattice.
 
@@ -1198,7 +1205,12 @@ def generate_nbest_list(
     '''
     # First, extract `num_paths` paths for each sequence.
     # path is a k2.RaggedInt with axes [seq][path][arc_pos]
-    path = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
+    path = _get_random_paths(
+        lattice=lats,
+        num_paths=num_paths,
+        use_double_scores=True,
+        scale=scale,
+    )
 
     # word_seq is a k2.RaggedInt sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
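_get_random_paths is an internal icefall helper; the relevant behaviour of scale is that the lattice scores are multiplied by it (in log space) before k2.random_paths samples, so scale < 1 flattens the arc posterior and yields a more diverse n-best list, while scale > 1 concentrates sampling on the best path. A self-contained sketch of that effect, not the helper's real implementation:

import torch

arc_log_probs = torch.tensor([0.0, -1.0, -4.0])  # hypothetical arc scores

def sampling_dist(scale: float) -> torch.Tensor:
    # scaling log-scores before normalizing acts as a temperature
    # on the posterior over arcs
    return torch.softmax(arc_log_probs * scale, dim=0)

print(sampling_dist(1.0))  # peaked on the best arc
print(sampling_dist(0.3))  # flatter: paths sampled more diversely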
@@ -1269,14 +1281,14 @@ def generate_nbest_list(
 
     path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-    # replace labels with tokens to remove repeat token IDs.
-    path_lattice.labels = path_lattice.tokens
-
     n_best = k2.shortest_path(path_lattice, use_double_scores=True)
 
+    n_best.labels = n_best.tokens
+
     n_best = k2.remove_epsilon(n_best)
 
     n_best = k2.top_sort(k2.connect(n_best))
+    # import pdb; pdb.set_trace()
 
     # now we have nbest lists with am_scores and lm_scores
     return Nbest(fsa=n_best, shape=seq_to_path_shape)
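The final hunk defers the label swap until after shortest_path and only then removes epsilons: each extracted path is a linear FSA, and dropping its 0-labeled (epsilon) arcs leaves exactly the token sequence the attention decoder scores. Schematically, in plain Python rather than k2, with hypothetical arc labels:

# illustrative per-arc token IDs along one path; 0 acts as epsilon
arc_tokens = [6, 0, 0, 13, 0, 27]
# k2.remove_epsilon drops the 0-labeled arcs, leaving a linear FSA
# over the real token IDs
assert [t for t in arc_tokens if t != 0] == [6, 13, 27]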