update for new AsrModel

This commit is contained in:
Desh Raj 2023-06-22 02:30:53 -04:00
parent c7c0bbff81
commit 4325bb20b9
3 changed files with 18 additions and 4 deletions

View File

@@ -50,6 +50,7 @@ def fast_beam_search_one_best(
subtract_ilme: bool = False, subtract_ilme: bool = False,
ilme_scale: float = 0.1, ilme_scale: float = 0.1,
return_timestamps: bool = False, return_timestamps: bool = False,
allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@@ -92,6 +93,7 @@ def fast_beam_search_one_best(
temperature=temperature, temperature=temperature,
subtract_ilme=subtract_ilme, subtract_ilme=subtract_ilme,
ilme_scale=ilme_scale, ilme_scale=ilme_scale,
allow_partial=allow_partial,
) )
best_path = one_best_decoding(lattice) best_path = one_best_decoding(lattice)
@@ -115,6 +117,7 @@ def fast_beam_search_nbest_LG(
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, return_timestamps: bool = False,
allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@@ -168,6 +171,7 @@ def fast_beam_search_nbest_LG(
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature, temperature=temperature,
allow_partial=allow_partial,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@@ -241,6 +245,7 @@ def fast_beam_search_nbest(
use_double_scores: bool = True, use_double_scores: bool = True,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, return_timestamps: bool = False,
allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@@ -294,6 +299,7 @@ def fast_beam_search_nbest(
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature, temperature=temperature,
allow_partial=allow_partial,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@@ -332,6 +338,7 @@ def fast_beam_search_nbest_oracle(
nbest_scale: float = 0.5, nbest_scale: float = 0.5,
temperature: float = 1.0, temperature: float = 1.0,
return_timestamps: bool = False, return_timestamps: bool = False,
allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]: ) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@@ -389,6 +396,7 @@ def fast_beam_search_nbest_oracle(
max_states=max_states, max_states=max_states,
max_contexts=max_contexts, max_contexts=max_contexts,
temperature=temperature, temperature=temperature,
allow_partial=allow_partial,
) )
nbest = Nbest.from_lattice( nbest = Nbest.from_lattice(
@@ -434,6 +442,7 @@ def fast_beam_search(
temperature: float = 1.0, temperature: float = 1.0,
subtract_ilme: bool = False, subtract_ilme: bool = False,
ilme_scale: float = 0.1, ilme_scale: float = 0.1,
allow_partial: bool = False,
) -> k2.Fsa: ) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
@@ -517,7 +526,9 @@ def fast_beam_search(
log_probs -= ilme_scale * ilme_log_probs log_probs -= ilme_scale * ilme_log_probs
decoding_streams.advance(log_probs) decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams() decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist()) lattice = decoding_streams.format_output(
encoder_out_lens.tolist(), allow_partial=allow_partial
)
return lattice return lattice

View File

@@ -385,6 +385,7 @@ def decode_one_batch(
beam=params.beam, beam=params.beam,
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
allow_partial=True,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyp = [w for w in hyp.split() if w != unk] hyp = [w for w in hyp.split() if w != unk]
@@ -400,6 +401,7 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
num_paths=params.num_paths, num_paths=params.num_paths,
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
allow_partial=True,
) )
for hyp in hyp_tokens: for hyp in hyp_tokens:
hyp = [word_table[i] for i in hyp if word_table[i] != unk] hyp = [word_table[i] for i in hyp if word_table[i] != unk]
@@ -415,6 +417,7 @@ def decode_one_batch(
max_states=params.max_states, max_states=params.max_states,
num_paths=params.num_paths, num_paths=params.num_paths,
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
allow_partial=True,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyp = [w for w in hyp.split() if w != unk] hyp = [w for w in hyp.split() if w != unk]
@@ -431,6 +434,7 @@ def decode_one_batch(
num_paths=params.num_paths, num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]), ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale, nbest_scale=params.nbest_scale,
allow_partial=True,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyp = [w for w in hyp.split() if w != unk] hyp = [w for w in hyp.split() if w != unk]

View File

@@ -68,7 +68,7 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
from model import Transducer from model import AsrModel
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from scaling import ScheduledFloat from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling from subsampling import Conv2dSubsampling
@@ -585,14 +585,13 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)
model = Transducer( model = AsrModel(
encoder_embed=encoder_embed, encoder_embed=encoder_embed,
encoder=encoder, encoder=encoder,
decoder=decoder, decoder=decoder,
joiner=joiner, joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))), encoder_dim=int(max(params.encoder_dim.split(","))),
decoder_dim=params.decoder_dim, decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
) )
return model return model