update for new AsrModel

This commit is contained in:
Desh Raj 2023-06-22 02:30:53 -04:00
parent c7c0bbff81
commit 4325bb20b9
3 changed files with 18 additions and 4 deletions

View File

@@ -50,6 +50,7 @@ def fast_beam_search_one_best(
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
return_timestamps: bool = False,
allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -92,6 +93,7 @@ def fast_beam_search_one_best(
temperature=temperature,
subtract_ilme=subtract_ilme,
ilme_scale=ilme_scale,
allow_partial=allow_partial,
)
best_path = one_best_decoding(lattice)
@@ -115,6 +117,7 @@ def fast_beam_search_nbest_LG(
use_double_scores: bool = True,
temperature: float = 1.0,
return_timestamps: bool = False,
allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -168,6 +171,7 @@ def fast_beam_search_nbest_LG(
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
allow_partial=allow_partial,
)
nbest = Nbest.from_lattice(
@@ -241,6 +245,7 @@ def fast_beam_search_nbest(
use_double_scores: bool = True,
temperature: float = 1.0,
return_timestamps: bool = False,
allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -294,6 +299,7 @@ def fast_beam_search_nbest(
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
allow_partial=allow_partial,
)
nbest = Nbest.from_lattice(
@@ -332,6 +338,7 @@ def fast_beam_search_nbest_oracle(
nbest_scale: float = 0.5,
temperature: float = 1.0,
return_timestamps: bool = False,
allow_partial: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -389,6 +396,7 @@ def fast_beam_search_nbest_oracle(
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
allow_partial=allow_partial,
)
nbest = Nbest.from_lattice(
@@ -434,6 +442,7 @@ def fast_beam_search(
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
allow_partial: bool = False,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
@@ -517,7 +526,9 @@ def fast_beam_search(
log_probs -= ilme_scale * ilme_log_probs
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
lattice = decoding_streams.format_output(
encoder_out_lens.tolist(), allow_partial=allow_partial
)
return lattice

View File

@@ -385,6 +385,7 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
allow_partial=True,
)
for hyp in sp.decode(hyp_tokens):
hyp = [w for w in hyp.split() if w != unk]
@@ -400,6 +401,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
allow_partial=True,
)
for hyp in hyp_tokens:
hyp = [word_table[i] for i in hyp if word_table[i] != unk]
@@ -415,6 +417,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
allow_partial=True,
)
for hyp in sp.decode(hyp_tokens):
hyp = [w for w in hyp.split() if w != unk]
@@ -431,6 +434,7 @@ def decode_one_batch(
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
allow_partial=True,
)
for hyp in sp.decode(hyp_tokens):
hyp = [w for w in hyp.split() if w != unk]

View File

@@ -68,7 +68,7 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
from model import Transducer
from model import AsrModel
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
@@ -585,14 +585,13 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
model = AsrModel(
encoder_embed=encoder_embed,
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=int(max(params.encoder_dim.split(","))),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
)
return model