mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-11 02:52:18 +00:00
Fix for style check
This commit is contained in:
parent
f43ea5db30
commit
5cf4e2df07
@ -592,7 +592,7 @@ def greedy_search(
|
|||||||
1, context_size + 1
|
1, context_size + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
k[:, 0] = torch.sum(
|
k = torch.sum(
|
||||||
(
|
(
|
||||||
c[:, -context_size - 1 : -1]
|
c[:, -context_size - 1 : -1]
|
||||||
== c[:, -1].expand_as(c[:, -context_size - 1 : -1])
|
== c[:, -1].expand_as(c[:, -context_size - 1 : -1])
|
||||||
@ -671,11 +671,9 @@ def greedy_search_batch(
|
|||||||
# on which hyp[n][i] is decoded
|
# on which hyp[n][i] is decoded
|
||||||
timestamps = [[] for _ in range(N)]
|
timestamps = [[] for _ in range(N)]
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(hyps, device=device, dtype=torch.int64)[
|
||||||
hyps,
|
:, -context_size:
|
||||||
device=device,
|
] # (N, context_size)
|
||||||
dtype=torch.int64,
|
|
||||||
)[:, -context_size : ] # (N, context_size)
|
|
||||||
|
|
||||||
k = torch.zeros(N, 1, device=device, dtype=torch.int64)
|
k = torch.zeros(N, 1, device=device, dtype=torch.int64)
|
||||||
|
|
||||||
@ -720,10 +718,12 @@ def greedy_search_batch(
|
|||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
|
|
||||||
k[:, 0] = torch.sum(
|
k = torch.sum(
|
||||||
(
|
(
|
||||||
c[:, : context_size]
|
c[:, : context_size]
|
||||||
== c[:, context_size : context_size + 1].expand_as(c[:, : context_size])
|
== c[:, context_size : context_size + 1].expand_as(
|
||||||
|
c[:, : context_size]
|
||||||
|
)
|
||||||
).int(),
|
).int(),
|
||||||
dim=1,
|
dim=1,
|
||||||
keepdim=True,
|
keepdim=True,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user