From 899f8586590d2d5bf1109b16a2751104e6a0d352 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 25 May 2023 12:20:29 +0800 Subject: [PATCH] Add blank-penalty to other decoding method --- .../beam_search.py | 28 ++++++++++++++++++- .../pruned_transducer_stateless7/decode.py | 21 +++++++++++--- .../ASR/pruned_transducer_stateless7/train.py | 2 +- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 40c1654a8..70df7bc08 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -49,6 +49,7 @@ def fast_beam_search_one_best( temperature: float = 1.0, subtract_ilme: bool = False, ilme_scale: float = 0.1, + blank_penalty: float = 0.0, return_timestamps: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. @@ -92,6 +93,7 @@ def fast_beam_search_one_best( temperature=temperature, subtract_ilme=subtract_ilme, ilme_scale=ilme_scale, + blank_penalty=blank_penalty, ) best_path = one_best_decoding(lattice) @@ -114,6 +116,7 @@ def fast_beam_search_nbest_LG( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, + blank_penalty: float = 0.0, return_timestamps: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. @@ -168,6 +171,7 @@ def fast_beam_search_nbest_LG( max_states=max_states, max_contexts=max_contexts, temperature=temperature, + blank_penalty=blank_penalty, ) nbest = Nbest.from_lattice( @@ -240,6 +244,7 @@ def fast_beam_search_nbest( nbest_scale: float = 0.5, use_double_scores: bool = True, temperature: float = 1.0, + blank_penalty: float = 0.0, return_timestamps: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. 
@@ -293,6 +298,7 @@ def fast_beam_search_nbest( beam=beam, max_states=max_states, max_contexts=max_contexts, + blank_penalty=blank_penalty, temperature=temperature, ) @@ -331,6 +337,7 @@ def fast_beam_search_nbest_oracle( use_double_scores: bool = True, nbest_scale: float = 0.5, temperature: float = 1.0, + blank_penalty: float = 0.0, return_timestamps: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. @@ -389,6 +396,7 @@ def fast_beam_search_nbest_oracle( max_states=max_states, max_contexts=max_contexts, temperature=temperature, + blank_penalty=blank_penalty, ) nbest = Nbest.from_lattice( @@ -434,6 +442,7 @@ def fast_beam_search( temperature: float = 1.0, subtract_ilme: bool = False, ilme_scale: float = 0.1, + blank_penalty: float = 0.0, ) -> k2.Fsa: """It limits the maximum number of symbols per frame to 1. @@ -503,6 +512,8 @@ def fast_beam_search( project_input=False, ) logits = logits.squeeze(1).squeeze(1) + if blank_penalty != 0: + logits[:, 0] -= blank_penalty log_probs = (logits / temperature).log_softmax(dim=-1) if subtract_ilme: ilme_logits = model.joiner( @@ -526,6 +537,7 @@ def greedy_search( model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int, + blank_penalty: float = 0.0, return_timestamps: bool = False, ) -> Union[List[int], DecodingResults]: """Greedy search for a single utterance. 
@@ -595,6 +607,9 @@ def greedy_search( ) # logits is (1, 1, 1, vocab_size) + if blank_penalty != 0: + logits[:,:,:,0] -= blank_penalty + y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) @@ -704,7 +719,10 @@ def greedy_search_batch( logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) assert logits.ndim == 2, logits.shape - logits[:, 0] -= blank_penalty + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + y = logits.argmax(dim=1).tolist() emitted = False for i, v in enumerate(y): @@ -921,6 +939,7 @@ def modified_beam_search( encoder_out_lens: torch.Tensor, beam: int = 4, temperature: float = 1.0, + blank_penalty: float = 0.0, return_timestamps: bool = False, ) -> Union[List[List[int]], DecodingResults]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -1024,6 +1043,9 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -1628,6 +1650,7 @@ def beam_search( encoder_out: torch.Tensor, beam: int = 4, temperature: float = 1.0, + blank_penalty: float = 0.0, return_timestamps: bool = False, ) -> Union[List[int], DecodingResults]: """ @@ -1724,6 +1747,9 @@ def beam_search( project_input=False, ) + if blank_penalty != 0: + logits[:,:,:,0] -= blank_penalty + # TODO(fangjun): Scale the blank posterior log_prob = (logits / temperature).log_softmax(dim=-1) # log_prob is (1, 1, 1, vocab_size) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py index 231cbdccd..e3931509b 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless7/decode.py @@ -307,7 +307,12 @@ def get_parser(): "--blank-penalty", type=float, default=0.0, - help="", + help=""" + The penalty 
applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, ) add_model_arguments(parser) @@ -373,6 +378,7 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + blank_penalty=params.blank_penalty, ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) @@ -387,6 +393,7 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, ) for hyp in hyp_tokens: sentence = "".join([lexicon.word_table[i] for i in hyp]) @@ -402,6 +409,7 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) @@ -417,6 +425,7 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) @@ -431,6 +440,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, beam=params.beam_size, ) for i in range(encoder_out.size(0)): @@ -447,10 +457,12 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, ) elif params.decoding_method == "beam_search": hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size, + blank_penalty=params.blank_penalty, ) else: raise ValueError( @@ -458,10 +470,11 @@ def decode_one_batch( ) 
hyps.append([lexicon.token_table[idx] for idx in hyp]) + key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} + return {"greedy_search_" + key: hyps} elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" + key += f"_beam_{params.beam}_" key += f"max_contexts_{params.max_contexts}_" key += f"max_states_{params.max_states}" if "nbest" in params.decoding_method: @@ -472,7 +485,7 @@ def decode_one_batch( return {key: hyps} else: - return {f"beam_size_{params.beam_size}": hyps} + return {f"beam_size_{params.beam_size}_" + key: hyps} def decode_dataset( diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py index 84dca1028..5167f66f0 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py @@ -260,7 +260,7 @@ def get_parser(): parser.add_argument( "--lr-epochs", type=float, - default=3.5, + default=3.5, help="""Number of epochs that affects how rapidly the learning rate decreases. """, )