Fix the score of the nbest attention rescorer

This commit is contained in:
pkufool 2021-09-07 19:41:18 +08:00
parent 1ac52e5bcb
commit 355e3244c8
2 changed files with 38 additions and 25 deletions

View File

@ -125,8 +125,8 @@ def get_parser():
def get_params() -> AttributeDict:
params = AttributeDict(
{
# "exp_dir": Path("exp/conformer_ctc"),
"exp_dir": Path("conformer_ctc/exp"),
"exp_dir": Path("exp/conformer_ctc"),
# "exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
@ -330,6 +330,7 @@ def decode_one_batch(
rescore_est_model=rescore_est_model,
sos_id=sos_id,
eos_id=eos_id,
scale=params.lattice_score_scale,
)
if params.dump_best_matching_feature:
if best_path_dict.size()[0] > 0:
@ -612,13 +613,13 @@ def main():
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
if test_set == "test-other": continue
# if test_set == "test-other": continue
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
word_table=word_table,
word_table=lexicon.word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,

View File

@ -932,7 +932,7 @@ def rescore_nbest_with_attention_decoder(
token_seq = k2.ragged.remove_values_leq(token_seq, -1)
# token seq shape [utt][path][token]
token_seq = k2.ragged.remove_axis(token_seq, 2)
# token seq shape [utt][token]
# token seq shape [path][token]
token_seq = k2.ragged.remove_axis(token_seq, 0)
token_ids = k2.ragged.to_list(token_seq)
@ -944,7 +944,6 @@ def rescore_nbest_with_attention_decoder(
0, path_to_seq_map_long
)
# TODO: pass the sos_token_id and eos_token_id via function arguments
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
@ -987,6 +986,9 @@ def rescore_with_attention_decoder_v2(
rescore_est_model: nn.Module,
sos_id: int,
eos_id: int,
scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
) -> Union[torch.Tensor, Dict[str, k2.Fsa]]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
@ -1022,7 +1024,7 @@ def rescore_with_attention_decoder_v2(
ngram_lm_scale_attention_scale and the value is the
best decoding path for each sequence in the lattice.
"""
nbest = generate_nbest_list(lattice, num_paths)
nbest = generate_nbest_list(lattice, num_paths, scale)
if dump_best_matching_feature:
if nbest.fsa.arcs.dim0() <= 2 * top_k or nbest.fsa.arcs.num_elements() == 0:
@ -1101,14 +1103,11 @@ def rescore_with_attention_decoder_v2(
# tot_scores() shape [utt][path]
# path_score with elements numbers equals numbers of paths
# !!! Note: This is right only when utt equals to 1
path_scores = nbest_remain.total_scores().values()
best_score = torch.max(rescored_nbest_topk.total_scores().values())
est_scores = 1 - 1/2 * (
1 + torch.erf(
(best_score - path_mean) / torch.sqrt(2 * path_var)
)
)
est_scores = k2.RaggedFloat(nbest_remain.shape, est_scores)
est_scores = (path_mean - best_score) / torch.sqrt(path_var)
# print (f"best score : {best_score}, est_scores : {est_scores}")
est_scores = k2.RaggedFloat(nbest_remain.shape, -est_scores)
# calculate nbest_remain estimated score and select topk
nbest_remain_topk = nbest_remain.top_k(k=top_k, scores=est_scores)
@ -1134,9 +1133,10 @@ def rescore_with_attention_decoder_v2(
merge_fsa, k2.ragged.create_ragged_shape2(row_ids=merge_row_ids)
)
attention_scores = rescore_nbest.fsa.scores
attention_scores = -rescore_nbest.fsa.scores
am_scores = torch.cat((am_scores, remain_am_scores))
lm_scores = torch.cat((lm_scores, remain_lm_scores))
else:
am_scores = nbest.fsa.scores - nbest.fsa.lm_scores
lm_scores = nbest.fsa.lm_scores
@ -1148,13 +1148,20 @@ def rescore_with_attention_decoder_v2(
sos_id=sos_id,
eos_id=eos_id
)
attention_scores = rescore_nbest.fsa.scores
attention_scores = -rescore_nbest.fsa.scores
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
if attention_scale is None:
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
attention_scale_list = [attention_scale]
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
ans = dict()
for n_scale in ngram_lm_scale_list:
@ -1180,7 +1187,7 @@ def rescore_with_attention_decoder_v2(
def generate_nbest_list(
lats: k2.Fsa,
num_paths: int,
aux_labels: bool = False
scale: float = 1.0
) -> Nbest:
'''Generate an n-best list from a lattice.
@ -1198,7 +1205,12 @@ def generate_nbest_list(
'''
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
path = _get_random_paths(
lattice=lats,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
@ -1269,14 +1281,14 @@ def generate_nbest_list(
path_lattice = k2.top_sort(k2.connect(path_lattice))
# replace labels with tokens to remove repeat token IDs.
path_lattice.labels = path_lattice.tokens
n_best = k2.shortest_path(path_lattice, use_double_scores=True)
n_best.labels = n_best.tokens
n_best = k2.remove_epsilon(n_best)
n_best = k2.top_sort(k2.connect(n_best))
# import pdb; pdb.set_trace()
# now we have nbest lists with am_scores and lm_scores
return Nbest(fsa=n_best, shape=seq_to_path_shape)