From 4325bb20b95bbc8621684c24fcaa1986e7cd553e Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Thu, 22 Jun 2023 02:30:53 -0400 Subject: [PATCH] update for new AsrModel --- .../ASR/pruned_transducer_stateless2/beam_search.py | 13 ++++++++++++- egs/tedlium3/ASR/zipformer/decode.py | 4 ++++ egs/tedlium3/ASR/zipformer/train.py | 5 ++--- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 1bbad6946..ca19ea0ce 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -50,6 +50,7 @@ def fast_beam_search_one_best( subtract_ilme: bool = False, ilme_scale: float = 0.1, return_timestamps: bool = False, + allow_partial: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. @@ -92,6 +93,7 @@ def fast_beam_search_one_best( temperature=temperature, subtract_ilme=subtract_ilme, ilme_scale=ilme_scale, + allow_partial=allow_partial, ) best_path = one_best_decoding(lattice) @@ -115,6 +117,7 @@ def fast_beam_search_nbest_LG( use_double_scores: bool = True, temperature: float = 1.0, return_timestamps: bool = False, + allow_partial: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. @@ -168,6 +171,7 @@ def fast_beam_search_nbest_LG( max_states=max_states, max_contexts=max_contexts, temperature=temperature, + allow_partial=allow_partial, ) nbest = Nbest.from_lattice( @@ -241,6 +245,7 @@ def fast_beam_search_nbest( use_double_scores: bool = True, temperature: float = 1.0, return_timestamps: bool = False, + allow_partial: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. @@ -294,6 +299,7 @@ def fast_beam_search_nbest( max_states=max_states, max_contexts=max_contexts, temperature=temperature, + allow_partial=allow_partial, ) nbest = Nbest.from_lattice( @@ -332,6 +338,7 @@ def fast_beam_search_nbest_oracle( nbest_scale: float = 0.5, temperature: float = 1.0, return_timestamps: bool = False, + allow_partial: bool = False, ) -> Union[List[List[int]], DecodingResults]: """It limits the maximum number of symbols per frame to 1. @@ -389,6 +396,7 @@ def fast_beam_search_nbest_oracle( max_states=max_states, max_contexts=max_contexts, temperature=temperature, + allow_partial=allow_partial, ) nbest = Nbest.from_lattice( @@ -434,6 +442,7 @@ def fast_beam_search( temperature: float = 1.0, subtract_ilme: bool = False, ilme_scale: float = 0.1, + allow_partial: bool = False, ) -> k2.Fsa: """It limits the maximum number of symbols per frame to 1. @@ -517,7 +526,9 @@ def fast_beam_search( log_probs -= ilme_scale * ilme_log_probs decoding_streams.advance(log_probs) decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) + lattice = decoding_streams.format_output( + encoder_out_lens.tolist(), allow_partial=allow_partial + ) return lattice diff --git a/egs/tedlium3/ASR/zipformer/decode.py b/egs/tedlium3/ASR/zipformer/decode.py index 6f109780a..ea1cbba1b 100755 --- a/egs/tedlium3/ASR/zipformer/decode.py +++ b/egs/tedlium3/ASR/zipformer/decode.py @@ -385,6 +385,7 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + allow_partial=True, ) for hyp in sp.decode(hyp_tokens): hyp = [w for w in hyp.split() if w != unk] @@ -400,6 +401,7 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + allow_partial=True, ) for hyp in hyp_tokens: hyp = [word_table[i] for i in hyp if word_table[i] != unk] @@ -415,6 +417,7 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + allow_partial=True, ) for hyp in sp.decode(hyp_tokens): hyp = [w for w in hyp.split() if w != unk] @@ -431,6 +434,7 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, + allow_partial=True, ) for hyp in sp.decode(hyp_tokens): hyp = [w for w in hyp.split() if w != unk] diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 2217d4cb6..64938e24f 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -68,7 +68,7 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids -from model import Transducer +from model import AsrModel from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling @@ -585,14 +585,13 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: decoder = get_decoder_model(params) joiner = get_joiner_model(params) - model = Transducer( + model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, decoder=decoder, joiner=joiner, encoder_dim=int(max(params.encoder_dim.split(","))), decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, ) return model