Fix for style check

This commit is contained in:
yifanyang 2023-02-06 18:42:45 +08:00
parent f43ea5db30
commit 5cf4e2df07

View File

@ -592,7 +592,7 @@ def greedy_search(
1, context_size + 1 1, context_size + 1
) )
k[:, 0] = torch.sum( k = torch.sum(
( (
c[:, -context_size - 1 : -1] c[:, -context_size - 1 : -1]
== c[:, -1].expand_as(c[:, -context_size - 1 : -1]) == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
@ -671,11 +671,9 @@ def greedy_search_batch(
# on which hyp[n][i] is decoded # on which hyp[n][i] is decoded
timestamps = [[] for _ in range(N)] timestamps = [[] for _ in range(N)]
decoder_input = torch.tensor( decoder_input = torch.tensor(hyps, device=device, dtype=torch.int64)[
hyps, :, -context_size:
device=device, ] # (N, context_size)
dtype=torch.int64,
)[:, -context_size : ] # (N, context_size)
k = torch.zeros(N, 1, device=device, dtype=torch.int64) k = torch.zeros(N, 1, device=device, dtype=torch.int64)
@ -720,10 +718,12 @@ def greedy_search_batch(
dtype=torch.int64, dtype=torch.int64,
) )
k[:, 0] = torch.sum( k = torch.sum(
( (
c[:, : context_size] c[:, : context_size]
== c[:, context_size : context_size + 1].expand_as(c[:, : context_size]) == c[:, context_size : context_size + 1].expand_as(
c[:, : context_size]
)
).int(), ).int(),
dim=1, dim=1,
keepdim=True, keepdim=True,