diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 7aec25a84..efd5b1c66 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -6,6 +6,7 @@ import argparse import logging +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -23,10 +24,12 @@ from icefall.decode import ( nbest_decoding, one_best_decoding, rescore_with_attention_decoder, + rescore_with_attention_decoder_v2, rescore_with_n_best_list, rescore_with_whole_lattice, ) from icefall.lexicon import Lexicon +from icefall.score_estimator import ScoreEstimator from icefall.utils import ( AttributeDict, get_texts, @@ -62,6 +65,7 @@ def get_parser(): def get_params() -> AttributeDict: params = AttributeDict( { + # "exp_dir": Path("exp/conformer_ctc"), "exp_dir": Path("conformer_ctc/exp"), "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), @@ -86,10 +90,17 @@ def get_params() -> AttributeDict: # - whole-lattice-rescoring # - attention-decoder # "method": "whole-lattice-rescoring", - "method": "attention-decoder", + "method": "attention-decoder-v2", + # "method": "nbest-rescoring", + # "method": "attention-decoder", # num_paths is used when method is "nbest", "nbest-rescoring", # and attention-decoder "num_paths": 100, + # top_k is used when method is "attention-decoder-v2" + "top_k" : 10, + # dump_best_matching_feature is used when method is + # "attention-decoder-v2" to dump feature to train a special model + "dump_best_matching_feature": False, } ) return params @@ -104,6 +115,7 @@ def decode_one_batch( lexicon: Lexicon, sos_id: int, eos_id: int, + rescore_est_model: nn.Module, G: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[int]]]: """Decode one batch and return the result in a dict. The dict has the @@ -135,12 +147,16 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + batch_idx: + The batch index of current batch. lexicon: It contains word symbol table. sos_id: The token ID of the SOS. eos_id: The token ID of the EOS. + rescore_est_model: + The model to estimate rescore mean and variance, only for attention-decoder-v2 G: An LM. It is not None when params.method is "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in HLG @@ -242,15 +258,24 @@ def decode_one_batch( best_path_dict = rescore_with_attention_decoder_v2( lattice=rescored_lattice, batch_idx=batch_idx, - dump_best_matching_feature=params.dump_feature, + dump_best_matching_feature=params.dump_best_matching_feature, num_paths=params.num_paths, top_k=params.top_k, model=model, memory=memory, memory_key_padding_mask=memory_key_padding_mask, + rescore_est_model=rescore_est_model, sos_id=sos_id, eos_id=eos_id, ) + if params.dump_best_matching_feature: + if best_path_dict.size()[0] > 0: + save_dir = params.exp_dir / f"rescore/feat" + if not os.path.exists(save_dir): + os.makedirs(save_dir) + file_name = save_dir / f"feats-epoch-{batch_idx}.pt" + torch.save(best_path_dict, file_name) + return dict() else: assert False, f"Unsupported decoding method: {params.method}" @@ -270,6 +295,7 @@ def decode_dataset( lexicon: Lexicon, sos_id: int, eos_id: int, + rescore_est_model: nn.Module, G: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[int], List[int]]]]: """Decode dataset. @@ -289,6 +315,8 @@ def decode_dataset( The token ID for SOS. eos_id: The token ID for EOS. 
+ rescore_est_model: + The model to estimate rescore mean and variance, only for attention-decoder-v2 G: An LM. It is not None when params.method is "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in HLG @@ -303,7 +331,7 @@ def decode_dataset( results = [] num_cuts = 0 - tot_num_cuts = len(dl.dataset.cuts) + tot_batches = len(dl) results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -314,11 +342,12 @@ def decode_dataset( model=model, HLG=HLG, batch=batch, - batch_idx, + batch_idx=batch_idx, lexicon=lexicon, G=G, sos_id=sos_id, eos_id=eos_id, + rescore_est_model=rescore_est_model, ) for lm_scale, hyps in hyps_dict.items(): @@ -334,9 +363,8 @@ def decode_dataset( if batch_idx % 100 == 0: logging.info( - f"batch {batch_idx}, cuts processed until now is " - f"{num_cuts}/{tot_num_cuts} " - f"({float(num_cuts)/tot_num_cuts*100:.6f}%)" + f"batch {batch_idx}/{tot_batches}, cuts processed until now is " + f"{num_cuts}" ) return results @@ -430,6 +458,7 @@ def main(): "nbest-rescoring", "whole-lattice-rescoring", "attention-decoder", + "attention-decoder-v2", ): if not (params.lm_dir / "G_4_gram.pt").is_file(): logging.info("Loading G_4_gram.fst.txt") @@ -453,7 +482,7 @@ def main(): d = torch.load(params.lm_dir / "G_4_gram.pt") G = k2.Fsa.from_dict(d).to(device) - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + if params.method in ["whole-lattice-rescoring", "attention-decoder", "attention-decoder-v2"]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = k2.add_epsilon_self_loops(G) @@ -465,6 +494,15 @@ def main(): G.lm_scores = G.scores.clone() else: G = None + if params.method == "attention-decoder-v2": + rescore_est_model = ScoreEstimator() + rescore_est_model.load_state_dict( + torch.load(f"{params.exp_dir}/rescore/epoch-19.pt", + map_location="cpu") + ) + rescore_est_model.to(device) + else: + rescore_est_model = None model = Conformer( num_features=params.feature_dim, @@ -504,6 +542,7 @@ def main(): # test_sets = ["test-clean", "test-other"] for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + if test_set == "test-other": continue results_dict = decode_dataset( dl=test_dl, params=params, @@ -513,6 +552,7 @@ def main(): G=G, sos_id=sos_id, eos_id=eos_id, + rescore_est_model=rescore_est_model, ) save_results( diff --git a/icefall/decode.py b/icefall/decode.py index c7d3a86e1..ddca95749 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1,10 +1,15 @@ import logging +import os from typing import Dict, List, Optional, Tuple, Union import k2 import torch import torch.nn as nn +from .nbest import Nbest +from .utils import get_best_matching_stats + +from .score_estimator import ScoreEstimator def _intersect_device( a_fsas: k2.Fsa, @@ -752,20 +757,25 @@ def rescore_nbest_with_attention_decoder( eos_id: The token ID for EOS. Returns: - A dict of FsaVec, whose key contains a string - ngram_lm_scale_attention_scale and the value is the - best decoding path for each sequence in the lattice. + A Nbest with all of the scores on fsa arcs updated with attention scores. 
""" - num_seqs = nbest.shape.Dim0() - token_seq = k2.RaggedInt(nbest.shape, nbest.fsas.labels().contiguous()) + num_paths = nbest.shape.num_elements() + # token shape [utt][path][state][arc] + token_shape = k2.ragged.compose_ragged_shapes(nbest.shape, nbest.fsa.arcs.shape()) + + token_seq = k2.RaggedInt(token_shape, nbest.fsa.labels.contiguous()) # Remove -1 from token_seq, there is no epsilon tokens in token_seq, we # removed it when generating nbest list token_seq = k2.ragged.remove_values_leq(token_seq, -1) + # token seq shape [utt][path][token] + token_seq = k2.ragged.remove_axis(token_seq, 2) + # token seq shape [utt][token] + token_seq = k2.ragged.remove_axis(token_seq, 0) token_ids = k2.ragged.to_list(token_seq) - path_to_seq_map_long = token_seq.shape.row_ids(1).to(torch.long) + path_to_seq_map_long = nbest.shape.row_ids(1).to(torch.long) expanded_memory = memory.index_select(1, path_to_seq_map_long) expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( @@ -780,25 +790,27 @@ def rescore_nbest_with_attention_decoder( sos_id=sos_id, eos_id=eos_id, ) + assert nll.ndim == 2 - assert nll.shape[0] == num_seqs + assert nll.shape[0] == num_paths attention_scores = torch.zeros( - nbest.fsas.labels().size()[0], + nbest.fsa.scores.size()[0], dtype=torch.float32, - device=nbest.device + device=nbest.fsa.device ) + start_index = 0 - for i in range(num_seqs): + for i in range(num_paths): # Plus 1 to fill the score of final arc - tokens_num = len(tokens_ids[i]) + 1 - attention_scores[start_index: start_index + tokens_num] = + tokens_num = 0 if len(token_ids[i]) == 0 else len(token_ids[i]) + 1 + attention_scores[start_index: start_index + tokens_num] =\ nll[i][0: tokens_num] start_index += tokens_num - fsas = nbest.fsas.clone() - fsas.score = attention_scores - return Nbest(fsas, nbest.shape.clone()) + fsas = nbest.fsa.clone() + fsas.scores = attention_scores + return Nbest(fsas, nbest.shape) def rescore_with_attention_decoder_v2( @@ -810,9 +822,10 @@ def rescore_with_attention_decoder_v2( model: nn.Module, memory: torch.Tensor, memory_key_padding_mask: torch.Tensor, + rescore_est_model: nn.Module, sos_id: int, eos_id: int, -) -> Dict[str, k2.Fsa]: +) -> Union[torch.Tensor, Dict[str, k2.Fsa]]: """This function extracts n paths from the given lattice and uses an attention decoder to rescore them. The path with the highest score is used as the decoding output. @@ -820,6 +833,11 @@ def rescore_with_attention_decoder_v2( Args: lattice: An FsaVec. It can be the return value of :func:`get_lattice`. + batch_idx: + The batch index currently processed. + dump_best_matching_feature: + Whether to dump best matching feature, only for preparing training + data for attention-decoder-v2 num_paths: Number of paths to extract from the given lattice for rescoring. model: @@ -831,6 +849,8 @@ def rescore_with_attention_decoder_v2( Its shape is `[T, N, C]`. memory_key_padding_mask: The padding mask for memory with shape [N, T]. + rescore_est_model: + The model to estimate rescore mean and variance, only for attention-decoder-v2 sos_id: The token ID for SOS. eos_id: @@ -841,23 +861,24 @@ def rescore_with_attention_decoder_v2( best decoding path for each sequence in the lattice. 
""" nbest = generate_nbest_list(lattice, num_paths) - # Now we have nbest with scores - nbest = nbest.intersect(lattice) if dump_best_matching_feature: + if nbest.fsa.arcs.dim0() <= 2 * top_k or nbest.fsa.arcs.num_elements() == 0: + return torch.empty(0) nbest_k, nbest_q = nbest.split(k=top_k, sort=False) + rescored_nbest_k = rescore_nbest_with_attention_decoder( nbest=nbest_k, model=model, memory=memory, memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, - eos_id=eos_id, + eos_id=eos_id ) stats_tensor = get_best_matching_stats( rescored_nbest_k, nbest_q, - max_order=3 + max_order=5 ) rescored_nbest_q = rescore_nbest_with_attention_decoder( nbest=nbest_q, @@ -865,41 +886,132 @@ def rescore_with_attention_decoder_v2( memory=memory, memory_key_padding_mask=memory_key_padding_mask, sos_id=sos_id, + eos_id=eos_id + ) + merge_tensor = torch.cat( + (stats_tensor, rescored_nbest_q.fsa.scores.clone().view(-1, 1)), + dim=1 + ) + return merge_tensor + + if nbest.fsa.arcs.dim0() >= 2 * top_k and nbest.fsa.arcs.num_elements() != 0: + nbest_topk, nbest_remain = nbest.split(k=top_k) + + am_scores = nbest_topk.fsa.scores - nbest_topk.fsa.lm_scores + lm_scores = nbest_topk.fsa.lm_scores + + rescored_nbest_topk = rescore_nbest_with_attention_decoder( + nbest=nbest_topk, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, eos_id=eos_id, - # return feature & label or dump to file + ) - nbest_topk, nbest_remain = nbest.split(k=top_k) + stats_tensor = get_best_matching_stats( + rescored_nbest_topk, + nbest_remain, + max_order=5 + ) - rescored_nbest_topk = rescore_nbest_with_attention_decoder( - nbest=nbest_topk, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, - ) - stats_tensor = get_best_matching_stats( - rescored_nbest_topk, - nbest_remain, - max_order=3 - ) - # run rescore estimation model to get the mean and var of each token - mean, var = rescore_est_model(stats_tensor) - # calculate nbest_remain estimated score and select topk - nbest_remain_topk = nbest_remain.top_k(k=top_k) - rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder( - nbest=nbest_remain_topk, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, - ) - best_path_dict=get_best_path_from_nbests( - rescored_nbest_topk, - rescored_nbest_remain_topk, - ) + # run rescore estimation model to get the mean and var of each token + mean, var = rescore_est_model(stats_tensor) + # mean_shape [utt][path][state][arcs] + mean_shape = k2.ragged.compose_ragged_shapes( + nbest_remain.shape, nbest_remain.fsa.arcs.shape()) + # mean_shape [utt][path][arcs] + mean_shape = k2.ragged.remove_axis(mean_shape, 2) + ragged_mean = k2.RaggedFloat(mean_shape, mean.contiguous()) + # path mean shape [utt][path] + path_mean = k2.ragged.sum_per_sublist(ragged_mean) + + # var_shape [utt][path][state][arcs] + var_shape = k2.ragged.compose_ragged_shapes( + nbest_remain.shape, nbest_remain.fsa.arcs.shape()) + # var_shape [utt][path][arcs] + var_shape = k2.ragged.remove_axis(var_shape, 2) + ragged_var = k2.RaggedFloat(var_shape, var.contiguous()) + # path var shape [utt][path] + path_var = k2.ragged.sum_per_sublist(ragged_var) + + # tot_scores() shape [utt][path] + # path_score with elements numbers equals numbers of paths + # !!! 
Note: This is right only when utt equals to 1 + path_scores = nbest_remain.total_scores().values() + best_score = torch.max(rescored_nbest_topk.total_scores().values()) + est_scores = 1 - 1/2 * ( + 1 + torch.erf( + (best_score - path_mean) / torch.sqrt(2 * path_var) + ) + ) + est_scores = k2.RaggedFloat(nbest_remain.shape, est_scores) + + # calculate nbest_remain estimated score and select topk + nbest_remain_topk = nbest_remain.top_k(k=top_k, scores=est_scores) + remain_am_scores = nbest_remain_topk.fsa.scores - nbest_remain_topk.fsa.lm_scores + remain_lm_scores = nbest_remain_topk.fsa.lm_scores + rescored_nbest_remain_topk = rescore_nbest_with_attention_decoder( + nbest=nbest_remain_topk, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + ) + + # !!! Note: This is right only when utt equals to 1 + merge_fsa = k2.cat([rescored_nbest_topk.fsa, rescored_nbest_remain_topk.fsa]) + merge_row_ids = torch.zeros( + merge_fsa.arcs.dim0(), + dtype=torch.int32, + device=merge_fsa.device + ) + rescore_nbest = Nbest( + merge_fsa, k2.ragged.create_ragged_shape2(row_ids=merge_row_ids) + ) + + attention_scores = rescore_nbest.fsa.scores + am_scores = torch.cat((am_scores, remain_am_scores)) + lm_scores = torch.cat((lm_scores, remain_lm_scores)) + else: + am_scores = nbest.fsa.scores - nbest.fsa.lm_scores + lm_scores = nbest.fsa.lm_scores + rescore_nbest = rescore_nbest_with_attention_decoder( + nbest=nbest, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id + ) + attention_scores = rescore_nbest.fsa.scores + + ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + + attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + + ans = dict() + for n_scale in ngram_lm_scale_list: + for a_scale in attention_scale_list: + tot_scores = ( + am_scores + + n_scale * lm_scores + + a_scale * attention_scores + ) + rescore_nbest.fsa.scores = tot_scores + # ragged tot scores shape [utt][path] + ragged_tot_scores = rescore_nbest.total_scores() + + argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + + best_fsas = k2.index_fsa(rescore_nbest.fsa, argmax_indexes) + + key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}" + ans[key] = best_fsas return ans @@ -920,49 +1032,90 @@ def generate_nbest_list( that represent the same word sequences, the number of paths in different sequences may not be equal. Return: - Return an Nbest object. Note the returned FSAs don't have epsilon - self-loops. + Return an Nbest object. ''' - assert len(lats.shape) == 3 - # First, extract `num_paths` paths for each sequence. - # paths is a k2.RaggedInt with axes [seq][path][arc_pos] - paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) + # path is a k2.RaggedInt with axes [seq][path][arc_pos] + path = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) - # Seqs is a k2.RaggedInt sharing the same shape as `paths`. - # Note that it also contains 0s and -1s. + # word_seq is a k2.RaggedInt sharing the same shape as `path` + # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. 
-    # Its axes are [seq][path][word_id]
-    if aux_labels:
-        # if aux_labels enable, seqs contains word_id
-        assert hasattr(lats, "aux_labels")
-        seqs = k2.index(lats.aux_labels, paths)
-    else:
-        # CAUTION: We use `phones` instead of `tokens` here because
-        # :func:`compile_HLG` uses `phones`
-        #
-        # Note: compile_HLG is from k2-fsa/snowfall
-        assert hasattr(lats, 'phones')
+    word_seq = k2.index(lats.aux_labels, path)
 
-        assert not hasattr(lats, 'tokens')
-        lats.tokens = lats.phones
-        seqs = k2.index(lats.tokens, paths)
+    # Remove epsilons and -1 from word_seq
+    word_seq = k2.ragged.remove_values_leq(word_seq, 0)
 
-    # Remove epsilons (0s) and -1 from word_seqs
-    seqs = k2.ragged.remove_values_leq(seqs, 0)
+    # Remove paths that have identical word sequences.
+    #
+    # unique_word_seq is still a k2.RaggedInt with 3 axes [seq][path][word]
+    # except that there are no repeated paths with the same word_seq
+    # within a sequence.
+    #
+    # num_repeats is also a k2.RaggedInt with 2 axes containing the
+    # multiplicities of each path.
+    # num_repeats.num_elements() == unique_word_seq.num_elements()
+    #
+    # Since k2.ragged.unique_sequences will reorder paths within a seq,
+    # `new2old` is a 1-D torch.Tensor mapping from the output path index
+    # to the input path index.
+    # new2old.numel() == unique_word_seq.tot_size(1)
+    unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
+        word_seq, need_num_repeats=True, need_new2old_indexes=True
+    )
 
-    # unique_word_seqs is still a k2.RaggedInt with axes [seq][path][word_id].
-    # But then number of pathsin each sequence may be different.
-    unique_seqs, _, _ = k2.ragged.unique_sequences(
-        seqs, need_num_repeats=False, need_new2old_indexes=False)
+    seq_to_path_shape = k2.ragged.get_layer(unique_word_seq.shape(), 0)
 
-    seq_to_path_shape = k2.ragged.get_layer(unique_seqs.shape(), 0)
+    # path_to_seq_map is a 1-D torch.Tensor.
+    # path_to_seq_map[i] is the seq to which the i-th path
+    # belongs.
+    path_to_seq_map = seq_to_path_shape.row_ids(1)
 
     # Remove the seq axis.
-    # Now unique_word_seqs has only two axes [path][word_id]
-    unique_seqs = k2.ragged.remove_axis(unique_seqs, 0)
+    # Now unique_word_seq has only two axes [path][word]
+    unique_word_seq = k2.ragged.remove_axis(unique_word_seq, 0)
 
-    fsas = k2.linear_fsa(unique_seqs)
+    # word_fsa is an FsaVec with axes [path][state][arc]
+    word_fsa = k2.linear_fsa(unique_word_seq)
 
-    return Nbest(fsa=fsas, shape=seq_to_path_shape)
+    word_fsa_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsa)
+
+    # k2.compose() currently does not support b_to_a_map. To avoid
+    # replicating `lats`, we use k2.intersect_device here.
+    #
+    # lattice has token IDs as `labels` and word IDs as aux_labels, so we
+    # need to invert it here.
+    inv_lattice = k2.invert(lats)
+
+    # Now the `labels` of inv_lattice are word IDs (a 1-D torch.Tensor)
+    # and its `aux_labels` are token IDs (a k2.RaggedInt with 2 axes).
+
+    # We could remove its `aux_labels` since they are not needed in the
+    # following computation.
+    # del inv_lattice.aux_labels
+    inv_lattice = k2.arc_sort(inv_lattice)
+
+    path_lattice = _intersect_device(
+        inv_lattice,
+        word_fsa_with_epsilon_loops,
+        b_to_a_map=path_to_seq_map,
+        sorted_match_a=True,
+    )
+
+    # path_lattice now has token IDs as `labels` and word IDs as aux_labels.
+    path_lattice = k2.invert(path_lattice)
+
+    path_lattice = k2.top_sort(k2.connect(path_lattice))
+
+    # Replace labels with tokens to remove repeated token IDs.
+    path_lattice.labels = path_lattice.tokens
+
+    n_best = k2.shortest_path(path_lattice, use_double_scores=True)
+
+    n_best = k2.remove_epsilon(n_best)
+
+    n_best = k2.top_sort(k2.connect(n_best))
+
+    # Now we have n-best lists with am_scores and lm_scores
+    return Nbest(fsa=n_best, shape=seq_to_path_shape)
diff --git a/icefall/nbest.py b/icefall/nbest.py
index 14d44e227..cf2efbda0 100644
--- a/icefall/nbest.py
+++ b/icefall/nbest.py
@@ -82,8 +82,11 @@ class Nbest(object):
 
         one_best = k2.remove_epsilon(one_best)
 
+        one_best = k2.top_sort(k2.connect(one_best))
+
         return Nbest(fsa=one_best, shape=self.shape)
 
+
     def total_scores(self) -> k2.RaggedFloat:
         '''Get total scores of the FSAs in this Nbest.
 
@@ -100,7 +103,7 @@ class Nbest(object):
         # If k2.RaggedDouble is wrapped, we can use double precision here.
         return k2.RaggedFloat(self.shape, scores.float())
 
-    def top_k(self, k: int) -> 'Nbest':
+    def top_k(self, k: int, scores: k2.RaggedFloat = None) -> 'Nbest':
         '''Get a subset of paths in the Nbest. The resulting Nbest is regular
         in that each sequence (i.e., utterance) has the same number of
         paths (k).
@@ -113,10 +116,14 @@ class Nbest(object):
         Args:
           k:
             Number of paths in each utterance.
+          scores:
+            The scores used to select the top-k paths; if None, total_scores() is used.
         Returns:
           Return a new Nbest with a regular shape.
         '''
-        ragged_scores = self.total_scores()
+        ragged_scores = scores
+        if ragged_scores is None:
+            ragged_scores = self.total_scores()
 
         # indexes contains idx01's for self.shape
         # ragged_scores.values()[indexes] is sorted
@@ -140,6 +147,7 @@ class Nbest(object):
 
         top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
                                                      dim1=k)
+        top_k_shape = top_k_shape.to(top_k_fsas.device)
 
         return Nbest(top_k_fsas, top_k_shape)
 
@@ -163,7 +171,7 @@ class Nbest(object):
         # indexes contains idx01's for self.shape
         indexes = torch.arange(
            self.shape.num_elements(), dtype=torch.int32,
-           device=self.shape.device
+           device=self.fsa.device
        )
 
        if sort:
@@ -176,9 +184,12 @@ class Nbest(object):
 
        ragged_indexes = k2.RaggedInt(self.shape, indexes)
 
-       padded_indexes = k2.ragged.pad(ragged_indexes, value=-1)
+       padded_indexes = k2.ragged.pad(ragged_indexes,
+                                      value=-1)
 
        # Select the idx01's of top-k paths of each utterance
+       max_num_fsa = padded_indexes.size()[1]
+
        first_indexes = padded_indexes[:, :k].flatten().contiguous()
 
        # Remove the padding elements
diff --git a/icefall/score_estimator.py b/icefall/score_estimator.py
new file mode 100644
index 000000000..f84414a7a
--- /dev/null
+++ b/icefall/score_estimator.py
@@ -0,0 +1,188 @@
+import argparse
+import glob
+import logging
+from pathlib import Path
+from typing import Tuple, List
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall.utils import (
+    setup_logger,
+    str2bool,
+)
+
+
+class Dataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        path: Path,
+        mode: str,
+    ) -> None:
+        super().__init__()
+        files = sorted(glob.glob(f"{path}/*.pt"))
+        if mode == 'train':
+            self.files = files[0: int(len(files) * 0.8)]
+        elif mode == 'dev':
+            self.files = files[int(len(files) * 0.8): int(len(files) * 0.9)]
+        elif mode == 'test':
+            self.files = files[int(len(files) * 0.9):]
+
+    def __getitem__(self, index) -> torch.Tensor:
+        return torch.load(self.files[index])
+
+    def __len__(self) -> int:
+        return len(self.files)
+
+
+class DatasetCollateFunc:
+    def __call__(self, batch: List) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = torch.cat(batch)
+        return (x[:, 0:5], x[:, 5])
+
+
+class ScoreEstimator(nn.Module):
+    def __init__(
+        self,
+        input_dim: int = 5,
+        hidden_dim: int = 20,
+    ) -> None:
+        super().__init__()
+        self.embedding = nn.Linear(
+            in_features=input_dim,
+            out_features=hidden_dim
+        )
+        self.output = nn.Linear(
+            in_features=hidden_dim,
+            out_features=2
+        )
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = self.embedding(x)
+        x = self.sigmoid(x)
+        x = self.output(x)
+        mean, var = x[:, 0], x[:, 1]
+        var = torch.exp(var)
+        return mean, var
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--input-dim",
+        type=int,
+        default=5,
+        help="Dimension of the input feature.",
+    )
+
+    parser.add_argument(
+        "--hidden-dim",
+        type=int,
+        default=20,
+        help="Number of units in the hidden layer.",
+    )
+
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=10,
+        help="Batch size of the dataloader.",
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=20,
+        help="Number of training epochs.",
+    )
+
+    parser.add_argument(
+        "--learning-rate",
+        type=float,
+        default=1e-4,
+        help="Learning rate.",
+    )
+
+    parser.add_argument(
+        "--exp_dir",
+        type=Path,
+        default=Path("conformer_ctc/exp"),
+        help="Directory to store experiment data.",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
+
+    setup_logger(f"{args.exp_dir}/rescore/log")
+
+    model = ScoreEstimator(
+        input_dim=args.input_dim,
+        hidden_dim=args.hidden_dim
+    )
+
+    model = model.to("cuda")
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+    loss_fn = nn.GaussianNLLLoss()
+
+    train_dataloader = DataLoader(
+        Dataset(f"{args.exp_dir}/rescore/feat", "train"),
+        collate_fn=DatasetCollateFunc(),
+        batch_size=args.batch_size,
+        shuffle=True
+    )
+    dev_dataloader = DataLoader(
+        Dataset(f"{args.exp_dir}/rescore/feat", "dev"),
+        collate_fn=DatasetCollateFunc(),
+        batch_size=args.batch_size,
+        shuffle=True
+    )
+
+    for epoch in range(args.epoch):
+        model.train()
+        training_loss = 0.0
+        step = 0
+        for x, y in train_dataloader:
+            mean, var = model(x.cuda())
+            loss = loss_fn(mean, y.cuda(), var)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            training_loss += loss.item()
+            step += len(y)
+        training_loss /= step
+
+        dev_loss = 0.0
+        step = 0
+        model.eval()
+        for x, y in dev_dataloader:
+            mean, var = model(x.cuda())
+            loss = loss_fn(mean, y.cuda(), var)
+            dev_loss += loss.item()
+            step += len(y)
+        dev_loss /= step
+
+        logging.info(f"Epoch {epoch}: training loss: {training_loss}, "
+                     f"dev loss: {dev_loss}"
+                     )
+        torch.save(
+            model.state_dict(),
+            f"{args.exp_dir}/rescore/epoch-{epoch}.pt"
+        )
+
+
+if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    main()
+
diff --git a/icefall/utils.py b/icefall/utils.py
index 1bf0f88de..2f414b2ba 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -411,6 +411,10 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest,
     assert keys.shape.dim0() == queries.shape.dim0(), \
         f'Utterances number in keys and queries should be equal : \
         {keys.shape.dim0()} vs {queries.shape.dim0()}'
+    assert keys.fsa.device == queries.fsa.device, \
+        f'Device of keys and queries should be equal : \
+        {keys.fsa.device} vs {queries.fsa.device}'
+    device = keys.fsa.device
 
     # keys_tokens_shape [utt][path][token]
     keys_tokens_shape = k2.ragged.compose_ragged_shapes(keys.shape,
@@ -430,11 +434,13 @@ def get_best_matching_stats(keys: Nbest,
queries: Nbest, # counts on key positions are ones keys_counts = k2.RaggedInt(keys_tokens_shape, torch.ones(keys_token_num, - dtype=torch.int32)) + dtype=torch.int32, + device=device)) # counts on query positions are zeros queries_counts = k2.RaggedInt(queries_tokens_shape, torch.zeros(queries_tokens_num, - dtype=torch.int32)) + dtype=torch.int32, + device=device)) counts = k2.ragged.cat([keys_counts, queries_counts], axis=1).values() # scores on key positions are the scores inherted from nbest path @@ -442,7 +448,8 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest, # scores on query positions MUST be zeros queries_scores = k2.RaggedFloat(queries_tokens_shape, torch.zeros(queries_tokens_num, - dtype=torch.float32)) + dtype=torch.float32, + device=device)) scores = k2.ragged.cat([keys_scores, queries_scores], axis=1).values() # we didn't remove -1 labels before @@ -450,8 +457,16 @@ def get_best_matching_stats(keys: Nbest, queries: Nbest, eos = -1 max_token = torch.max(torch.max(keys.fsa.labels), torch.max(queries.fsa.labels)) - mean, var, counts_out, ngram = k2.get_best_matching_stats(tokens, scores, - counts, eos, min_token, max_token, max_order) + mean, var, counts_out, ngram = k2.get_best_matching_stats( + tokens.to(torch.device('cpu')), scores.to(torch.device('cpu')), + counts.to(torch.device('cpu')), + eos, min_token, max_token, max_order + ) + + mean = mean.to(device) + var = var.to(device) + counts_out = counts_out.to(device) + ngram = ngram.to(device) queries_init_scores = queries.fsa.scores.clone() # only return the stats on query positions
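
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the ScoreEstimator predictions are meant to drive path selection in rescore_with_attention_decoder_v2. It assumes the patch above has been applied so that icefall.score_estimator.ScoreEstimator exists; the random stats tensor stands in for the 5-dimensional best-matching statistics produced by get_best_matching_stats (the exact column layout is an assumption here), and the per-path summation done with k2.ragged.sum_per_sublist in the real code is omitted, so each row already represents one remaining path.

    # Sketch only; mirrors the erf-based estimate used in
    # rescore_with_attention_decoder_v2. Not part of the patch.
    import torch

    from icefall.score_estimator import ScoreEstimator

    # Pretend there are 8 remaining (not yet rescored) paths, each
    # summarized by 5 best-matching statistics.
    stats = torch.rand(8, 5)

    model = ScoreEstimator(input_dim=5, hidden_dim=20)
    model.eval()

    with torch.no_grad():
        # Predicted mean and variance of the attention-decoder score.
        path_mean, path_var = model(stats)

    # Best attention score among the already rescored top-k paths
    # (a made-up value for this sketch).
    best_score = torch.tensor(12.3)

    # P(path score > best_score) under a Gaussian N(path_mean, path_var),
    # written with erf exactly as in the patch:
    # est = 1 - Phi((best - mean) / std).
    est_scores = 1 - 0.5 * (
        1 + torch.erf((best_score - path_mean) / torch.sqrt(2 * path_var))
    )

    # Keep only the most promising remaining paths for a second round of
    # attention-decoder rescoring (k=3 just for illustration).
    topk = torch.topk(est_scores, k=3)
    print(topk.indices, topk.values)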