diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index ce50c4cce..9519c1cf2 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -125,8 +125,8 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            # "exp_dir": Path("exp/conformer_ctc"),
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("exp/conformer_ctc"),
+            # "exp_dir": Path("conformer_ctc/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
@@ -330,6 +330,7 @@ def decode_one_batch(
             rescore_est_model=rescore_est_model,
             sos_id=sos_id,
             eos_id=eos_id,
+            scale=params.lattice_score_scale,
         )
     if params.dump_best_matching_feature:
         if best_path_dict.size()[0] > 0:
@@ -612,13 +613,13 @@ def main():
     # test_sets = ["test-clean", "test-other"]
     for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-        if test_set == "test-other": continue
+        # if test_set == "test-other": continue
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
             model=model,
             HLG=HLG,
-            word_table=word_table,
+            word_table=lexicon.word_table,
             G=G,
             sos_id=sos_id,
             eos_id=eos_id,
diff --git a/icefall/decode.py b/icefall/decode.py
index fad80c7c7..51a6fce6f 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -932,7 +932,7 @@ def rescore_nbest_with_attention_decoder(
     token_seq = k2.ragged.remove_values_leq(token_seq, -1)
     # token seq shape [utt][path][token]
     token_seq = k2.ragged.remove_axis(token_seq, 2)
-    # token seq shape [utt][token]
+    # token seq shape [path][token]
     token_seq = k2.ragged.remove_axis(token_seq, 0)
     token_ids = k2.ragged.to_list(token_seq)
@@ -944,7 +944,6 @@ def rescore_nbest_with_attention_decoder(
         0, path_to_seq_map_long
     )
 
-    # TODO: pass the sos_token_id and eos_token_id via function arguments
     nll = model.decoder_nll(
         memory=expanded_memory,
         memory_key_padding_mask=expanded_memory_key_padding_mask,
@@ -987,6 +986,9 @@ def rescore_with_attention_decoder_v2(
     rescore_est_model: nn.Module,
     sos_id: int,
     eos_id: int,
+    scale: float = 1.0,
+    ngram_lm_scale: Optional[float] = None,
+    attention_scale: Optional[float] = None,
 ) -> Union[torch.Tensor, Dict[str, k2.Fsa]]:
     """This function extracts n paths from the given lattice and uses
     an attention decoder to rescore them. The path with the highest
@@ -1022,7 +1024,7 @@
       ngram_lm_scale_attention_scale and the value is the best decoding path for
       each sequence in the lattice.
     """
-    nbest = generate_nbest_list(lattice, num_paths)
+    nbest = generate_nbest_list(lattice, num_paths, scale)
 
     if dump_best_matching_feature:
         if nbest.fsa.arcs.dim0() <= 2 * top_k or nbest.fsa.arcs.num_elements() == 0:
@@ -1101,14 +1103,11 @@
         # tot_scores() shape [utt][path]
         # path_score with elements numbers equals numbers of paths
         # !!! Note: This is right only when utt equals to 1
-        path_scores = nbest_remain.total_scores().values()
         best_score = torch.max(rescored_nbest_topk.total_scores().values())
-        est_scores = 1 - 1/2 * (
-            1 + torch.erf(
-                (best_score - path_mean) / torch.sqrt(2 * path_var)
-            )
-        )
-        est_scores = k2.RaggedFloat(nbest_remain.shape, est_scores)
+
+        est_scores = (path_mean - best_score) / torch.sqrt(path_var)
+        # print (f"best score : {best_score}, est_scores : {est_scores}")
+        est_scores = k2.RaggedFloat(nbest_remain.shape, -est_scores)
 
         # calculate nbest_remain estimated score and select topk
         nbest_remain_topk = nbest_remain.top_k(k=top_k, scores=est_scores)
@@ -1134,9 +1133,10 @@
             merge_fsa, k2.ragged.create_ragged_shape2(row_ids=merge_row_ids)
         )
-        attention_scores = rescore_nbest.fsa.scores
+        attention_scores = -rescore_nbest.fsa.scores
         am_scores = torch.cat((am_scores, remain_am_scores))
         lm_scores = torch.cat((lm_scores, remain_lm_scores))
+
     else:
         am_scores = nbest.fsa.scores - nbest.fsa.lm_scores
         lm_scores = nbest.fsa.lm_scores
@@ -1148,13 +1148,20 @@
             sos_id=sos_id,
             eos_id=eos_id
         )
-        attention_scores = rescore_nbest.fsa.scores
+        attention_scores = -rescore_nbest.fsa.scores
 
-    ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
-    ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    if ngram_lm_scale is None:
+        ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        ngram_lm_scale_list = [ngram_lm_scale]
+
+    if attention_scale is None:
+        attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
+        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+    else:
+        attention_scale_list = [attention_scale]
 
-    attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
-    attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
 
     ans = dict()
     for n_scale in ngram_lm_scale_list:
@@ -1180,7 +1187,7 @@
 
 def generate_nbest_list(
     lats: k2.Fsa,
     num_paths: int,
-    aux_labels: bool = False
+    scale: float = 1.0
 ) -> Nbest:
     '''Generate an n-best list from a lattice.
@@ -1198,7 +1205,12 @@
     '''
     # First, extract `num_paths` paths for each sequence.
     # path is a k2.RaggedInt with axes [seq][path][arc_pos]
-    path = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
+    path = _get_random_paths(
+        lattice=lats,
+        num_paths=num_paths,
+        use_double_scores=True,
+        scale=scale,
+    )
 
     # word_seq is a k2.RaggedInt sharing the same shape as `path`
     # but it contains word IDs. Note that it also contains 0s and -1s.
@@ -1269,14 +1281,14 @@
     path_lattice = k2.top_sort(k2.connect(path_lattice))
 
-    # replace labels with tokens to remove repeat token IDs.
-    path_lattice.labels = path_lattice.tokens
-    n_best = k2.shortest_path(path_lattice, use_double_scores=True)
+    n_best.labels = n_best.tokens
+    n_best = k2.remove_epsilon(n_best)
     n_best = k2.top_sort(k2.connect(n_best))
+    # import pdb; pdb.set_trace()
 
     # now we have nbest lists with am_scores and lm_scores
     return Nbest(fsa=n_best, shape=seq_to_path_shape)
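Note on the est_scores change (not part of the patch): the removed erf expression is the Gaussian survival function P(path_score > best_score), which is a strictly monotone function of the z-score (path_mean - best_score) / std. Ranking the remaining paths by the z-score therefore selects the same candidates while avoiding the erf evaluation. A minimal standalone sketch, with made-up path_mean / path_var values standing in for the per-path statistics used above:

# Standalone sketch: the removed erf formulation and the new z-score
# order paths identically. The tensor values below are illustrative only.
import torch

best_score = torch.tensor(-3.0)
path_mean = torch.tensor([-5.0, -4.0, -2.5])
path_var = torch.tensor([1.0, 4.0, 0.25])

# Removed formulation: 1 - Phi((best - mean) / std), the Gaussian
# survival function, i.e. P(path_score > best_score).
survival = 1 - 0.5 * (
    1 + torch.erf((best_score - path_mean) / torch.sqrt(2 * path_var))
)

# New formulation: the raw z-score (path_mean - best_score) / std.
z = (path_mean - best_score) / torch.sqrt(path_var)

# Phi is strictly increasing, so both quantities sort the paths the same
# way and top-k selection is unchanged.
assert torch.equal(torch.argsort(survival), torch.argsort(z))

The patch then negates the z-score before wrapping it in k2.RaggedFloat, apparently mirroring the sign flip it applies to attention_scores (rescore_nbest.fsa.scores becomes -rescore_nbest.fsa.scores) elsewhere in the same function.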
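For reference, the _get_random_paths helper that replaces the direct k2.random_paths call is assumed here to behave like icefall's helper of the same name: it scales the lattice scores before sampling, so scale < 1 flattens the arc posterior and k2.random_paths returns a more diverse set of paths. A sketch under that assumption (a paraphrase, not the verbatim upstream code):

# Sketch of the assumed `_get_random_paths` helper; not verbatim icefall code.
import k2


def _get_random_paths(
    lattice: k2.Fsa,
    num_paths: int,
    use_double_scores: bool = True,
    scale: float = 1.0,
) -> "k2.RaggedInt":
    """Sample paths from `lattice` with its scores scaled by `scale`.

    Smaller scales flatten the posterior over arcs, so the sampled paths
    are more diverse; the original scores are restored before returning.
    """
    saved_scores = lattice.scores.clone()
    lattice.scores = saved_scores * scale  # scale < 1 => flatter distribution
    path = k2.random_paths(
        lattice, num_paths=num_paths, use_double_scores=use_double_scores
    )
    lattice.scores = saved_scores  # restore the original scores
    return path

This is also why decode.py now threads params.lattice_score_scale through the new scale argument of rescore_with_attention_decoder_v2.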