from local

This commit is contained in:
dohe0342 2023-02-02 19:08:16 +09:00
parent 9c07cd95db
commit 5aadc8d5f9
2 changed files with 31 additions and 0 deletions

View File

@ -174,6 +174,37 @@ def decode_one_batch_greedy(
eos_id: int, eos_id: int,
token_dict: dict, token_dict: dict,
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is (N, T, C)
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
decoding_graph = H
if params.method == 'greedy-search' or params.method == 'ctc-decoding': if params.method == 'greedy-search' or params.method == 'ctc-decoding':
batch_size = nnet_output.size(0) batch_size = nnet_output.size(0)