mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
1d3ae3a0e6
commit
ab5a9f5ca2
Binary file not shown.
@ -163,6 +163,30 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch_greedy(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
HLG: Optional[k2.Fsa],
|
||||||
|
H: Optional[k2.Fsa],
|
||||||
|
batch: dict,
|
||||||
|
lexicon: Lexicon,
|
||||||
|
sos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
token_dict: dict,
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
|
||||||
|
if params.method == 'greedy-search' or params.method == 'ctc-decoding':
|
||||||
|
batch_size = nnet_output.size(0)
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = nnet_output[i].topk(1)
|
||||||
|
topk_indexes = topk_indexes.squeeze().unique_consecutive()
|
||||||
|
topk_indexes = topk_indexes[topk_indexes != 0]
|
||||||
|
hyp = ''
|
||||||
|
for idx in topk_indexes:
|
||||||
|
hyp += token_dict[idx.item()]
|
||||||
|
print(hyp)
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user