mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00

commit 4325bb20b9
parent c7c0bbff81

    update for new AsrModel
@@ -50,6 +50,7 @@ def fast_beam_search_one_best(
     subtract_ilme: bool = False,
     ilme_scale: float = 0.1,
     return_timestamps: bool = False,
+    allow_partial: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

@@ -92,6 +93,7 @@ def fast_beam_search_one_best(
         temperature=temperature,
         subtract_ilme=subtract_ilme,
         ilme_scale=ilme_scale,
+        allow_partial=allow_partial,
     )

     best_path = one_best_decoding(lattice)
@@ -115,6 +117,7 @@ def fast_beam_search_nbest_LG(
     use_double_scores: bool = True,
     temperature: float = 1.0,
     return_timestamps: bool = False,
+    allow_partial: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

@@ -168,6 +171,7 @@ def fast_beam_search_nbest_LG(
         max_states=max_states,
         max_contexts=max_contexts,
         temperature=temperature,
+        allow_partial=allow_partial,
     )

     nbest = Nbest.from_lattice(
@@ -241,6 +245,7 @@ def fast_beam_search_nbest(
     use_double_scores: bool = True,
     temperature: float = 1.0,
     return_timestamps: bool = False,
+    allow_partial: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

@@ -294,6 +299,7 @@ def fast_beam_search_nbest(
         max_states=max_states,
         max_contexts=max_contexts,
         temperature=temperature,
+        allow_partial=allow_partial,
     )

     nbest = Nbest.from_lattice(
@@ -332,6 +338,7 @@ def fast_beam_search_nbest_oracle(
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
     return_timestamps: bool = False,
+    allow_partial: bool = False,
 ) -> Union[List[List[int]], DecodingResults]:
     """It limits the maximum number of symbols per frame to 1.

@@ -389,6 +396,7 @@ def fast_beam_search_nbest_oracle(
         max_states=max_states,
         max_contexts=max_contexts,
         temperature=temperature,
+        allow_partial=allow_partial,
     )

     nbest = Nbest.from_lattice(
@@ -434,6 +442,7 @@ def fast_beam_search(
     temperature: float = 1.0,
     subtract_ilme: bool = False,
     ilme_scale: float = 0.1,
+    allow_partial: bool = False,
 ) -> k2.Fsa:
     """It limits the maximum number of symbols per frame to 1.

@@ -517,7 +526,9 @@ def fast_beam_search(
             log_probs -= ilme_scale * ilme_log_probs
         decoding_streams.advance(log_probs)
     decoding_streams.terminate_and_flush_to_streams()
-    lattice = decoding_streams.format_output(encoder_out_lens.tolist())
+    lattice = decoding_streams.format_output(
+        encoder_out_lens.tolist(), allow_partial=allow_partial
+    )

     return lattice

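
Taken together, these hunks thread one new flag, allow_partial, from each public fast_beam_search_* entry point down to k2's format_output call, so that decoding can still return (partial) hypotheses for streams whose lattices never reach a final state. A minimal sketch of a caller opting in; this is hedged, not part of the commit: model, encoder_out, and encoder_out_lens are assumed to come from an icefall recipe and are not defined here, and the vocabulary size is hypothetical.

    import k2

    from beam_search import fast_beam_search_one_best

    # Assumed inputs, not defined here: `model` is a trained icefall
    # transducer, and `encoder_out` (N, T, C) / `encoder_out_lens` (N,)
    # are its encoder outputs for a batch of utterances.
    decoding_graph = k2.trivial_graph(
        500,  # hypothetical: vocab_size - 1
        device=encoder_out.device,
    )

    hyp_tokens = fast_beam_search_one_best(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=20.0,
        max_contexts=8,
        max_states=64,
        allow_partial=True,  # new flag from this commit; the default stays False
    )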

@@ -385,6 +385,7 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            allow_partial=True,
         )
         for hyp in sp.decode(hyp_tokens):
             hyp = [w for w in hyp.split() if w != unk]
@@ -400,6 +401,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            allow_partial=True,
         )
         for hyp in hyp_tokens:
             hyp = [word_table[i] for i in hyp if word_table[i] != unk]
@@ -415,6 +417,7 @@ def decode_one_batch(
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            allow_partial=True,
         )
         for hyp in sp.decode(hyp_tokens):
             hyp = [w for w in hyp.split() if w != unk]
@@ -431,6 +434,7 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            allow_partial=True,
         )
         for hyp in sp.decode(hyp_tokens):
             hyp = [w for w in hyp.split() if w != unk]
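
All four fast-beam-search branches of decode_one_batch now pass allow_partial=True unconditionally, so batch decoding no longer fails on utterances that yield lattices without a final state. The remaining keyword arguments are forwarded from the recipe's params object; a hypothetical minimal configuration, using values typical of icefall recipes (not taken from this commit):

    from icefall.utils import AttributeDict

    params = AttributeDict(
        {
            "beam": 20.0,        # beam for fast beam search
            "max_contexts": 8,   # max contexts per decoding stream
            "max_states": 64,    # max states per decoding stream
            "num_paths": 200,    # n-best paths for the *_nbest variants
            "nbest_scale": 0.5,  # scale on lattice scores before sampling paths
        }
    )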

@@ -68,7 +68,7 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
-from model import Transducer
+from model import AsrModel
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
@@ -585,14 +585,13 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

-    model = Transducer(
+    model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
         encoder_dim=int(max(params.encoder_dim.split(","))),
         decoder_dim=params.decoder_dim,
-        joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
     )
     return model
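
Besides the rename from Transducer to AsrModel, the constructor drops joiner_dim, and encoder_dim is derived with int(max(params.encoder_dim.split(","))). Note that max() here compares strings lexicographically, which coincides with the numeric maximum only when the per-stack dimensions all have the same number of digits. A worked example with a hypothetical configuration string (not taken from this commit):

    # Hypothetical per-stack encoder dims, Zipformer-style.
    encoder_dim = "192,256,384,512,384,256"

    # Lexicographic max over the split strings happens to equal the
    # numeric max here, since every entry has three digits.
    assert int(max(encoder_dim.split(","))) == 512

    # An explicitly numeric variant would be:
    assert max(int(d) for d in encoder_dim.split(",")) == 512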