mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00

Fix the score of the nbest attention rescorer

parent 1ac52e5bcb
commit 355e3244c8
@@ -125,8 +125,8 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            # "exp_dir": Path("exp/conformer_ctc"),
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("exp/conformer_ctc"),
+            # "exp_dir": Path("conformer_ctc/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
@@ -330,6 +330,7 @@ def decode_one_batch(
             rescore_est_model=rescore_est_model,
             sos_id=sos_id,
             eos_id=eos_id,
+            scale=params.lattice_score_scale,
         )
         if params.dump_best_matching_feature:
             if best_path_dict.size()[0] > 0:
@@ -612,13 +613,13 @@ def main():
     #
     test_sets = ["test-clean", "test-other"]
     for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-        if test_set == "test-other": continue
+        # if test_set == "test-other": continue
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
             HLG=HLG,
-            word_table=word_table,
+            word_table=lexicon.word_table,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
@@ -932,7 +932,7 @@ def rescore_nbest_with_attention_decoder(
     token_seq = k2.ragged.remove_values_leq(token_seq, -1)
     # token seq shape [utt][path][token]
     token_seq = k2.ragged.remove_axis(token_seq, 2)
-    # token seq shape [utt][token]
+    # token seq shape [path][token]
     token_seq = k2.ragged.remove_axis(token_seq, 0)

     token_ids = k2.ragged.to_list(token_seq)
@@ -944,7 +944,6 @@ def rescore_nbest_with_attention_decoder(
         0, path_to_seq_map_long
     )

-    # TODO: pass the sos_token_id and eos_token_id via function arguments
     nll = model.decoder_nll(
         memory=expanded_memory,
         memory_key_padding_mask=expanded_memory_key_padding_mask,
@@ -987,6 +986,9 @@ def rescore_with_attention_decoder_v2(
     rescore_est_model: nn.Module,
     sos_id: int,
     eos_id: int,
+    scale: float = 1.0,
+    ngram_lm_scale: Optional[float] = None,
+    attention_scale: Optional[float] = None,
 ) -> Union[torch.Tensor, Dict[str, k2.Fsa]]:
     """This function extracts n paths from the given lattice and uses
     an attention decoder to rescore them. The path with the highest
@@ -1022,7 +1024,7 @@ def rescore_with_attention_decoder_v2(
       ngram_lm_scale_attention_scale and the value is the
       best decoding path for each sequence in the lattice.
     """
-    nbest = generate_nbest_list(lattice, num_paths)
+    nbest = generate_nbest_list(lattice, num_paths, scale)

     if dump_best_matching_feature:
         if nbest.fsa.arcs.dim0() <= 2 * top_k or nbest.fsa.arcs.num_elements() == 0:
@@ -1101,14 +1103,11 @@ def rescore_with_attention_decoder_v2(
         # tot_scores() shape [utt][path]
         # path_score with elements numbers equals numbers of paths
         # !!! Note: This is right only when utt equals to 1
-        path_scores = nbest_remain.total_scores().values()
         best_score = torch.max(rescored_nbest_topk.total_scores().values())
-        est_scores = 1 - 1/2 * (
-            1 + torch.erf(
-                (best_score - path_mean) / torch.sqrt(2 * path_var)
-            )
-        )
-        est_scores = k2.RaggedFloat(nbest_remain.shape, est_scores)
+        est_scores = (path_mean - best_score) / torch.sqrt(path_var)
+        # print (f"best score : {best_score}, est_scores : {est_scores}")
+        est_scores = k2.RaggedFloat(nbest_remain.shape, -est_scores)

         # calculate nbest_remain estimated score and select topk
         nbest_remain_topk = nbest_remain.top_k(k=top_k, scores=est_scores)
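For orientation, both the old and the new estimator above are functions of the standardized gap between a remaining path's estimated mean score and the best already-rescored score. The sketch below uses made-up values for `path_mean`, `path_var`, and `best_score` (in the real code they come from the best-matching-feature estimator and the rescored top-k paths); it only contrasts the two formulas. The old form is the Gaussian upper-tail probability 1 − Φ(z), which underflows to zero once the gap is large, while the new form keeps the gap on a linear scale before it is negated and wrapped into a `k2.RaggedFloat`.

```python
import torch

# Made-up statistics for three remaining paths (hypothetical values).
path_mean = torch.tensor([-9.0, -12.0, -30.0])
path_var = torch.tensor([4.0, 4.0, 4.0])
best_score = torch.tensor(-8.0)

# Old estimate: Gaussian upper-tail probability that a path beats best_score,
# i.e. 1 - Phi(z) with z = (best_score - path_mean) / sqrt(path_var).
old_est = 1 - 0.5 * (1 + torch.erf((best_score - path_mean) / torch.sqrt(2 * path_var)))

# New estimate as written in the hunk, before the final negation applied when
# it is wrapped into k2.RaggedFloat(nbest_remain.shape, -est_scores).
new_est = (path_mean - best_score) / torch.sqrt(path_var)

print(old_est)   # approx. [0.3085, 0.0228, 0.0000]; the large gap underflows to zero
print(-new_est)  # [0.5, 2.0, 11.0]; the standardized gap stays resolvable
```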
@@ -1134,9 +1133,10 @@ def rescore_with_attention_decoder_v2(
             merge_fsa, k2.ragged.create_ragged_shape2(row_ids=merge_row_ids)
         )
+
-        attention_scores = rescore_nbest.fsa.scores
+        attention_scores = -rescore_nbest.fsa.scores
         am_scores = torch.cat((am_scores, remain_am_scores))
         lm_scores = torch.cat((lm_scores, remain_lm_scores))

     else:
         am_scores = nbest.fsa.scores - nbest.fsa.lm_scores
         lm_scores = nbest.fsa.lm_scores
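The sign flip on `attention_scores` in both branches is plausibly the score fix named in the commit title: `rescore_nbest_with_attention_decoder` computes `nll = model.decoder_nll(...)` (visible in the hunk around line 944), and assuming the raw negative log-likelihood ends up on the n-best FSA scores, it has to be negated before being mixed with the AM and n-gram LM log-scores. A minimal numeric sketch, assuming the combination `am + ngram_lm_scale * lm + attention_scale * attention` used elsewhere in icefall's attention rescoring and made-up per-path values:

```python
import torch

# Hypothetical per-path scores: AM and LM are log-scores, the decoder output
# is a negative log-likelihood (a cost); index 1 is the decoder's favourite.
am_scores = torch.tensor([-20.0, -18.0, -25.0])
lm_scores = torch.tensor([-15.0, -14.0, -19.0])
decoder_nll = torch.tensor([5.0, 3.0, 30.0])

ngram_lm_scale, attention_scale = 0.7, 1.3

# Before the fix: the NLL is added as if it were a log-likelihood, so the path
# the decoder dislikes most (largest NLL) receives the biggest bonus.
tot_before = am_scores + ngram_lm_scale * lm_scores + attention_scale * decoder_nll

# After the fix: negate first, so all three terms point the same way.
attention_scores = -decoder_nll
tot_after = am_scores + ngram_lm_scale * lm_scores + attention_scale * attention_scores

print(tot_before.argmax().item(), tot_after.argmax().item())  # 2 1
```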
@@ -1148,13 +1148,20 @@ def rescore_with_attention_decoder_v2(
             sos_id=sos_id,
             eos_id=eos_id
         )
-        attention_scores = rescore_nbest.fsa.scores
+        attention_scores = -rescore_nbest.fsa.scores

-    ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
-    ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    if ngram_lm_scale is None:
+        ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        ngram_lm_scale_list = [ngram_lm_scale]

-    attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
-    attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    if attention_scale is None:
+        attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        attention_scale_list = [attention_scale]

     ans = dict()
     for n_scale in ngram_lm_scale_list:
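The new `ngram_lm_scale` and `attention_scale` arguments let a caller pin a single scale pair instead of sweeping the full grid. A small self-contained sketch of the same list-building logic (the helper name `build_scale_list` is hypothetical, introduced only for illustration):

```python
from typing import List, Optional

def build_scale_list(fixed: Optional[float]) -> List[float]:
    """Mirror the hunk above: sweep a grid of scales unless a value is pinned."""
    if fixed is None:
        scales = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
        scales += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
    else:
        scales = [fixed]
    return scales

# With both scales pinned, one (ngram_lm_scale, attention_scale) pair is
# evaluated instead of a 14 x 14 grid.
print(len(build_scale_list(None)) * len(build_scale_list(None)))  # 196
print(len(build_scale_list(0.7)) * len(build_scale_list(1.3)))    # 1
```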
@@ -1180,7 +1187,7 @@ def rescore_with_attention_decoder_v2(
 def generate_nbest_list(
     lats: k2.Fsa,
     num_paths: int,
-    aux_labels: bool = False
+    scale: float = 1.0
 ) -> Nbest:
     '''Generate an n-best list from a lattice.

@@ -1198,7 +1205,12 @@ def generate_nbest_list(
     '''
     # First, extract `num_paths` paths for each sequence.
     # path is a k2.RaggedInt with axes [seq][path][arc_pos]
-    path = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
+    path = _get_random_paths(
+        lattice=lats,
+        num_paths=num_paths,
+        use_double_scores=True,
+        scale=scale,
+    )

     # word_seq is a k2.RaggedInt sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
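`generate_nbest_list` now forwards the new `scale` argument to a `_get_random_paths` helper instead of calling `k2.random_paths` directly. In icefall this lattice score scale smooths the lattice scores before sampling, so the n-best list is not dominated by near-copies of the single best path. A rough sketch of what such a helper could look like, under that assumption (the real `_get_random_paths` lives elsewhere in the recipe and may differ):

```python
import k2

def _get_random_paths_sketch(lattice: k2.Fsa, num_paths: int,
                             use_double_scores: bool = True,
                             scale: float = 1.0):
    """Illustrative only: scale the lattice scores towards zero, sample paths,
    then restore the original scores."""
    saved_scores = lattice.scores.clone()
    # A smaller scale flattens the arc posteriors, so k2.random_paths draws a
    # more diverse set of paths from the lattice.
    lattice.scores = saved_scores * scale
    path = k2.random_paths(
        lattice, num_paths=num_paths, use_double_scores=use_double_scores
    )
    lattice.scores = saved_scores
    return path
```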
@@ -1269,14 +1281,14 @@ def generate_nbest_list(

     path_lattice = k2.top_sort(k2.connect(path_lattice))

-    # replace labels with tokens to remove repeat token IDs.
-    path_lattice.labels = path_lattice.tokens
-
     n_best = k2.shortest_path(path_lattice, use_double_scores=True)

+    n_best.labels = n_best.tokens
+
     n_best = k2.remove_epsilon(n_best)

     n_best = k2.top_sort(k2.connect(n_best))
+    # import pdb; pdb.set_trace()

     # now we have nbest lists with am_scores and lm_scores
     return Nbest(fsa=n_best, shape=seq_to_path_shape)