diff --git a/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh b/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh index 3efcc13e3..954312796 100755 --- a/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh +++ b/.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh @@ -1,8 +1,8 @@ #!/usr/bin/env bash # This script downloads the test-clean and test-other datasets -# of LibriSpeech and unzip them to the folder ~/tmp/download, -# which is cached by GitHub actions for later runs. +# of LibriSpeech and unzips them to the folder ~/tmp/download, +# which are cached by GitHub actions for later runs. # # You will find directories ~/tmp/download/LibriSpeech after running # this script. diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 3143fa077..569a5068c 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -10,6 +10,25 @@ During training, it selects either a batch from GigaSpeech with prob `giga_prob` or a batch from LibriSpeech with prob `1 - giga_prob`. All utterances within a batch come from the same dataset. +#### 2022-05-10 + +Using commit `TODO`. + +The WERs are: + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|----------------------------------------| +| greedy search (max sym per frame 1) | 2.21 | 5.09 | --epoch 27 --avg 2 --max-duration 600 | +| greedy search (max sym per frame 1) | 2.25 | 5.02 | --epoch 27 --avg 12 --max-duration 600 | +| modified beam search | 2.19 | 5.03 | --epoch 25 --avg 6 --max-duration 600 | +| modified beam search | 2.23 | 4.94 | --epoch 27 --avg 10 --max-duration 600 | +| beam search | 2.16 | 4.95 | --epoch 25 --avg 7 --max-duration 600 | +| fast beam search | 2.21 | 4.96 | --epoch 27 --avg 10 --max-duration 600 | +| fast beam search | 2.19 | 4.97 | --epoch 27 --avg 12 --max-duration 600 | + + +#### 2022-04-29 + Using commit `ac84220de91dee10c00e8f4223287f937b1930b6`. See . diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index ce8b04afd..e49f20e6e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional import k2 +import sentencepiece as spm import torch from model import Transducer @@ -34,10 +35,11 @@ def fast_beam_search_one_best( beam: float, max_states: int, max_contexts: int, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using modified beam search, and then + A lattice is first obtained using fast beam search, and then the shortest path within the lattice is used as the final output. Args: @@ -56,6 +58,8 @@ def fast_beam_search_one_best( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. + temperature: + Softmax temperature. Returns: Return the decoded result. """ @@ -67,6 +71,7 @@ def fast_beam_search_one_best( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) best_path = one_best_decoding(lattice) @@ -74,6 +79,85 @@ def fast_beam_search_one_best( return hyps +def fast_beam_search_nbest( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + we extract `num_paths` from the lattice using k2.random_path(), + unique them, compute the total score of each path by intersecting + it with the lattice, and output the path with the largest total score. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + Returns: + Return the decoded result. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + hyps = get_texts(best_path) + return hyps + + def fast_beam_search_nbest_oracle( model: Transducer, decoding_graph: k2.Fsa, @@ -86,10 +170,11 @@ def fast_beam_search_nbest_oracle( ref_texts: List[List[int]], use_double_scores: bool = True, nbest_scale: float = 0.5, + temperature: float = 1.0, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using modified beam search, and then + A lattice is first obtained using fast beam search, and then we select `num_paths` linear paths from the lattice. The path that has the minimum edit distance with the given reference transcript is used as the output. @@ -125,6 +210,8 @@ def fast_beam_search_nbest_oracle( nbest_scale: It's the scale applied to the lattice.scores. A smaller value yields more unique paths. + temperature: + Softmax temperature. Returns: Return the decoded result. @@ -137,6 +224,7 @@ def fast_beam_search_nbest_oracle( beam=beam, max_states=max_states, max_contexts=max_contexts, + temperature=temperature, ) nbest = Nbest.from_lattice( @@ -169,6 +257,158 @@ def fast_beam_search_nbest_oracle( return hyps +def fast_beam_search_with_nbest_rescoring( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + ngram_lm_scale_list: List[float], + num_paths: int, + G: k2.Fsa, + sp: spm.SentencePieceProcessor, + word_table: k2.SymbolTable, + oov_word: str = "", + use_double_scores: bool = True, + nbest_scale: float = 0.5, + temperature: float = 1.0, +) -> Dict[str, List[List[int]]]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using modified beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + ngram_lm_scale_list: + A list of floats representing LM score scales. + num_paths: + Number of paths to extract from the decoded lattice. + G: + An FsaVec containing only a single FSA. It is an n-gram LM. + sp: + The BPE model. + word_table: + The word symbol table. + oov_word: + OOV words are replaced with this word. + use_double_scores: + True to use double precision for computation. False to use + single precision. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + temperature: + Softmax temperature. + Returns: + Return the decoded result in a dict, where the key has the form + 'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the + ngram LM scale value used during decoding, i.e., 0.1. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + am_scores = nbest.tot_scores() + + # Now we need to compute the LM scores of each path. + # (1) Get the token IDs of each Path. We assume the decoding_graph + # is an acceptor, i.e., lattice is also an acceptor + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] + + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) + tokens = tokens.remove_values_leq(0) # remove -1 and 0 + + token_list: List[List[int]] = tokens.tolist() + word_list: List[List[str]] = sp.decode(token_list) + + assert isinstance(oov_word, str), oov_word + assert oov_word in word_table, oov_word + oov_word_id = word_table[oov_word] + + word_ids_list: List[List[int]] = [] + for words in word_list: + this_word_ids = [] + for w in words: + if w in word_table: + this_word_ids.append(word_table[w]) + else: + this_word_ids.append(oov_word_id) + word_ids_list.append(this_word_ids) + + word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) + word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) + + num_unique_paths = len(word_ids_list) + + b_to_a_map = torch.zeros( + num_unique_paths, + dtype=torch.int32, + device=lattice.device, + ) + + rescored_word_fsas = k2.intersect_device( + a_fsas=G, + b_fsas=word_fsas_with_self_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True, + ret_arc_maps=False, + ) + + rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) + ngram_lm_scores = rescored_word_fsas.get_tot_scores( + use_double_scores=True, + log_semiring=False, + ) + + ans: Dict[str, List[List[int]]] = {} + for s in ngram_lm_scale_list: + key = f"ngram_lm_scale_{s}" + tot_scores = am_scores.values + s * ngram_lm_scores + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + hyps = get_texts(best_path) + + ans[key] = hyps + + return ans + + def fast_beam_search( model: Transducer, decoding_graph: k2.Fsa, @@ -177,6 +417,7 @@ def fast_beam_search( beam: float, max_states: int, max_contexts: int, + temperature: float = 1.0, ) -> k2.Fsa: """It limits the maximum number of symbols per frame to 1. @@ -196,6 +437,8 @@ def fast_beam_search( Max states per stream per frame. max_contexts: Max contexts pre stream per frame. + temperature: + Softmax temperature. Returns: Return an FsaVec with axes [utt][state][arc] containing the decoded lattice. Note: When the input graph is a TrivialGraph, the returned @@ -244,7 +487,7 @@ def fast_beam_search( project_input=False, ) logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) + log_probs = (logits / temperature).log_softmax(dim=-1) decoding_streams.advance(log_probs) decoding_streams.terminate_and_flush_to_streams() lattice = decoding_streams.format_output(encoder_out_lens.tolist()) @@ -587,6 +830,7 @@ def modified_beam_search( encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, beam: int = 4, + temperature: float = 1.0, ) -> List[List[int]]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -600,6 +844,8 @@ def modified_beam_search( encoder_out before padding. beam: Number of active paths during the beam search. + temperature: + Softmax temperature. Returns: Return a list-of-list of token IDs. ans[i] is the decoding results for the i-th utterance. @@ -683,7 +929,7 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax(dim=-1) log_probs.add_(ys_log_probs) @@ -847,6 +1093,7 @@ def beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, + temperature: float = 1.0, ) -> List[int]: """ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf @@ -860,6 +1107,8 @@ def beam_search( A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. beam: Beam size. + temperature: + Softmax temperature. Returns: Return the decoded result. """ @@ -936,7 +1185,7 @@ def beam_search( ) # TODO(fangjun): Scale the blank posterior - log_prob = logits.log_softmax(dim=-1) + log_prob = (logits / temperature).log_softmax(dim=-1) # log_prob is (1, 1, 1, vocab_size) log_prob = log_prob.squeeze() # Now log_prob is (vocab_size,) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 5b3dce853..2a76dd31a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -19,40 +19,67 @@ Usage: (1) greedy search ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method greedy_search + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method greedy_search (2) beam search (not recommended) ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 (3) modified beam search ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 + +(5) fast beam search (nbest) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest with n-gram LM rescoring) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_with_nbest_rescoring \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 \ + --num-paths 200 \ + --nbest-scale 0.5 \ + --lm-dir ./data/lm """ @@ -69,8 +96,10 @@ import torch.nn as nn from asr_datamodule import AsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, fast_beam_search_one_best, + fast_beam_search_with_nbest_rescoring, greedy_search, greedy_search_batch, modified_beam_search, @@ -147,7 +176,9 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest - fast_beam_search_nbest_oracle + - fast_beam_search_with_nbest_rescoring """, ) @@ -168,7 +199,9 @@ def get_parser(): search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. Used only when --decoding-method is - fast_beam_search or fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring + """, ) parser.add_argument( @@ -176,7 +209,9 @@ def get_parser(): type=int, default=4, help="""Used only when --decoding-method is - fast_beam_search or fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring + """, ) parser.add_argument( @@ -184,7 +219,9 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search or fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring + """, ) parser.add_argument( @@ -194,6 +231,7 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -207,7 +245,8 @@ def get_parser(): type=int, default=100, help="""Number of paths for computed nbest oracle WER - when the decoding method is fast_beam_search_nbest_oracle. + when the decoding method is fast_beam_search_nbest_oracle, + fast_beam_search_nbest, or fast_beam_search_with_nbest_rescoring. """, ) @@ -216,9 +255,40 @@ def get_parser(): type=float, default=0.5, help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding_method is fast_beam_search_nbest_oracle. + Used only when the decoding_method is fast_beam_search_nbest_oracle, + fast_beam_search_nbest, or fast_beam_search_with_nbest_rescoring. """, ) + + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="""Softmax temperature. + The output of the model is (logits / temperature).log_softmax(). + """, + ) + + parser.add_argument( + "--lm-dir", + type=Path, + default=Path("./data/lm"), + help="""Used only when --decoding-method is + fast_beam_search_with_nbest_rescoring. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--words-txt", + type=Path, + default=Path("./data/lang_bpe_500/words.txt"), + help="""Used only when --decoding-method is + fast_beam_search_with_nbest_rescoring. + It is the word table. + """, + ) + return parser @@ -228,6 +298,8 @@ def decode_one_batch( sp: spm.SentencePieceProcessor, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + G: Optional[k2.Fsa] = None, + word_table: Optional[k2.SymbolTable] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -252,8 +324,17 @@ def decode_one_batch( for the format of the `batch`. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is - fast_beam_search or fast_beam_search_nbest_oracle. + only when decoding method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring. + G: + Optional. Used only when decoding method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, + or fast_beam_search_with_nbest_rescoring. + It an FsaVec containing an acceptor. + word_table: + Optional. Used only when decoding method is + fast_beam_search_with_nbest_rescoring. It is the word symbol table + containing mappings between words and IDs. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -282,6 +363,22 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + temperature=params.temperature, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + temperature=params.temperature, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -297,9 +394,30 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, + temperature=params.temperature, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_with_nbest_rescoring": + ngram_lm_scale_list = [-0.3, -0.2, -0.1, -0.05, -0.02, 0] + ngram_lm_scale_list += [0.01, 0.02, 0.05] + hyp_tokens = fast_beam_search_with_nbest_rescoring( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ngram_lm_scale_list=ngram_lm_scale_list, + num_paths=params.num_paths, + G=G, + sp=sp, + word_table=word_table, + use_double_scores=True, + nbest_scale=params.nbest_scale, + temperature=params.temperature, + ) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -317,6 +435,7 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + temperature=params.temperature, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -338,6 +457,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + temperature=params.temperature, ) else: raise ValueError( @@ -352,7 +472,19 @@ def decode_one_batch( ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" + f"max_states_{params.max_states}_" + f"temperature_{params.temperature}" + ): hyps + } + elif params.decoding_method == "fast_beam_search_nbest": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}_" + f"temperature_{params.temperature}" ): hyps } elif params.decoding_method == "fast_beam_search_nbest_oracle": @@ -362,11 +494,31 @@ def decode_one_batch( f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}_" f"num_paths_{params.num_paths}_" - f"nbest_scale_{params.nbest_scale}" + f"nbest_scale_{params.nbest_scale}_" + f"temperature_{params.temperature}" ): hyps } + elif params.decoding_method == "fast_beam_search_with_nbest_rescoring": + prefix = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}_" + f"temperature_{params.temperature}_" + ) + ans: Dict[str, List[List[str]]] = {} + for key, hyp in hyp_tokens.items(): + t: List[str] = sp.decode(hyp) + ans[prefix + key] = [s.split() for s in t] + return ans else: - return {f"beam_size_{params.beam_size}": hyps} + return { + ( + f"beam_size_{params.beam_size}_" + f"temperature_{params.temperature}" + ): hyps + } def decode_dataset( @@ -375,6 +527,8 @@ def decode_dataset( model: nn.Module, sp: spm.SentencePieceProcessor, decoding_graph: Optional[k2.Fsa] = None, + G: Optional[k2.Fsa] = None, + word_table: Optional[k2.SymbolTable] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -389,7 +543,17 @@ def decode_dataset( The BPE model. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when decoding method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring. + G: + Optional. Used only when decoding method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_oracle, + or fast_beam_search_with_nbest_rescoring. + It an FsaVec containing an acceptor. + word_table: + Optional. Used only when decoding method is + fast_beam_search_with_nbest_rescoring. It is the word symbol table + containing mappings between words and IDs. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -419,6 +583,8 @@ def decode_dataset( sp=sp, decoding_graph=decoding_graph, batch=batch, + G=G, + word_table=word_table, ) for name, hyps in hyps_dict.items(): @@ -438,6 +604,7 @@ def decode_dataset( logging.info( f"batch {batch_str}, cuts processed until now is {num_cuts}" ) + return results @@ -485,6 +652,68 @@ def save_results( logging.info(s) +def load_ngram_LM( + lm_dir: Path, word_table: k2.SymbolTable, device: torch.device +) -> k2.Fsa: + """Read a ngram model from the given directory. + + Args: + lm_dir: + It should contain either G_4_gram.pt or G_4_gram.fst.txt + word_table: + The word table mapping words to IDs and vice versa. + device: + The resulting FSA will be moved to this device. + Returns: + Return an FsaVec containing a single acceptor. + """ + lm_dir = Path(lm_dir) + assert lm_dir.is_dir(), f"{lm_dir} does not exist" + + pt_file = lm_dir / "G_4_gram.pt" + + if pt_file.is_file(): + logging.info(f"Loading pre-compiled {pt_file}") + d = torch.load(pt_file, map_location=device) + G = k2.Fsa.from_dict(d) + return G + + txt_file = lm_dir / "G_4_gram.fst.txt" + + assert txt_file.is_file(), f"{txt_file} does not exist" + logging.info(f"Loading {txt_file}") + logging.warning("It may take 8 minutes (Will be cached for later use).") + with open(txt_file) as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # Now G is an acceptor + + first_word_disambig_id = word_table["#0"] + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + logging.info(f"Saving to {pt_file} for later use") + torch.save(G.as_dict(), pt_file) + return G + + @torch.no_grad() def main(): parser = get_parser() @@ -499,7 +728,9 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", "fast_beam_search_nbest_oracle", + "fast_beam_search_with_nbest_rescoring", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -513,19 +744,27 @@ def main(): params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" - elif params.decoding_method == "fast_beam_search_nbest_oracle": + params.suffix += f"-temperature-{params.temperature}" + elif params.decoding_method in ( + "fast_beam_search_nbest", + "fast_beam_search_nbest_oracle", + "fast_beam_search_with_nbest_rescoring", + ): params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-temperature-{params.temperature}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" ) + params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-temperature-{params.temperature}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -585,12 +824,26 @@ def main(): if params.decoding_method in ( "fast_beam_search", + "fast_beam_search_nbest", "fast_beam_search_nbest_oracle", + "fast_beam_search_with_nbest_rescoring", ): decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) else: decoding_graph = None + if params.decoding_method == "fast_beam_search_with_nbest_rescoring": + logging.info(f"Loading word symbol table from {params.words_txt}") + word_table = k2.SymbolTable.from_file(params.words_txt) + G = load_ngram_LM( + lm_dir=params.lm_dir, + word_table=word_table, + device=device, + ) + else: + word_table = None + G = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -613,6 +866,8 @@ def main(): model=model, sp=sp, decoding_graph=decoding_graph, + G=G, + word_table=word_table, ) save_results( diff --git a/icefall/decode.py b/icefall/decode.py index 94f3e88ba..3ba899b4e 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -308,9 +308,7 @@ class Nbest(object): del word_fsa.aux_labels word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops( - word_fsa - ) + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) path_to_utt_map = self.shape.row_ids(1) @@ -609,7 +607,7 @@ def rescore_with_n_best_list( num_paths: Size of nbest list. lm_scale_list: - A list of float representing LM score scales. + A list of floats representing LM score scales. nbest_scale: Scale to be applied to ``lattice.score`` when sampling paths using ``k2.random_paths``.