mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
9c07cd95db
commit
5aadc8d5f9
Binary file not shown.
@ -175,6 +175,37 @@ def decode_one_batch_greedy(
|
|||||||
token_dict: dict,
|
token_dict: dict,
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
|
||||||
|
if HLG is not None:
|
||||||
|
device = HLG.device
|
||||||
|
else:
|
||||||
|
device = H.device
|
||||||
|
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
|
||||||
|
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
|
||||||
|
# nnet_output is (N, T, C)
|
||||||
|
|
||||||
|
supervision_segments = torch.stack(
|
||||||
|
(
|
||||||
|
supervisions["sequence_idx"],
|
||||||
|
supervisions["start_frame"] // params.subsampling_factor,
|
||||||
|
supervisions["num_frames"] // params.subsampling_factor,
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
).to(torch.int32)
|
||||||
|
|
||||||
|
if H is None:
|
||||||
|
assert HLG is not None
|
||||||
|
decoding_graph = HLG
|
||||||
|
else:
|
||||||
|
assert HLG is None
|
||||||
|
decoding_graph = H
|
||||||
|
|
||||||
if params.method == 'greedy-search' or params.method == 'ctc-decoding':
|
if params.method == 'greedy-search' or params.method == 'ctc-decoding':
|
||||||
batch_size = nnet_output.size(0)
|
batch_size = nnet_output.size(0)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user