From 5cf4e2df077ad9c531712e46258a586fcc1a2e66 Mon Sep 17 00:00:00 2001
From: yifanyang
Date: Mon, 6 Feb 2023 18:42:45 +0800
Subject: [PATCH] Fix for style check

---
 .../pruned_transducer_stateless9/beam_search.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
index 181ad253e..c5ae0669e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
@@ -592,7 +592,7 @@ def greedy_search(
                 1, context_size + 1
             )
-            k[:, 0] = torch.sum(
+            k = torch.sum(
                 (
                     c[:, -context_size - 1 : -1]
                     == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
@@ -671,11 +671,9 @@ def greedy_search_batch(
     # on which hyp[n][i] is decoded
     timestamps = [[] for _ in range(N)]
 
-    decoder_input = torch.tensor(
-        hyps,
-        device=device,
-        dtype=torch.int64,
-    )[:, -context_size : ]  # (N, context_size)
+    decoder_input = torch.tensor(hyps, device=device, dtype=torch.int64)[
+        :, -context_size:
+    ]  # (N, context_size)
 
     k = torch.zeros(N, 1, device=device, dtype=torch.int64)
 
@@ -720,10 +718,12 @@ def greedy_search_batch(
                 dtype=torch.int64,
             )
-            k[:, 0] = torch.sum(
+            k = torch.sum(
                 (
                     c[:, : context_size]
-                    == c[:, context_size : context_size + 1].expand_as(c[:, : context_size])
+                    == c[:, context_size : context_size + 1].expand_as(
+                        c[:, : context_size]
+                    )
                 ).int(),
                 dim=1,
                 keepdim=True,
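
Why k = torch.sum(...) can replace k[:, 0] = torch.sum(...): with dim=1 and keepdim=True the sum already has shape (N, 1), matching the preallocated k = torch.zeros(N, 1, ...), so a plain rebinding suffices. Below is a minimal standalone sketch of the same computation from the greedy_search_batch hunk, assuming c stacks the last context_size tokens of each hypothesis followed by the newly decoded token; the toy values are illustrative, not taken from the recipe.

import torch

# Sketch of the k computation in the patched greedy_search_batch hunk.
# Assumption (not from the patch): c holds, per hypothesis, the last
# context_size tokens followed by the newly decoded token,
# shape (N, context_size + 1).
context_size = 2
c = torch.tensor(
    [
        [5, 7, 7],  # newest token 7 appears once in the context
        [3, 3, 3],  # newest token 3 appears twice
        [1, 2, 9],  # newest token 9 is unseen in the context
    ],
    dtype=torch.int64,
)

# Count, per row, how many of the first context_size tokens equal the newest
# token.  keepdim=True keeps shape (N, 1), so the result can be assigned to k
# directly instead of writing into k[:, 0].
k = torch.sum(
    (
        c[:, :context_size]
        == c[:, context_size : context_size + 1].expand_as(c[:, :context_size])
    ).int(),
    dim=1,
    keepdim=True,
)
print(k)  # tensor([[1], [2], [0]])

Here k[n] counts how many of the preceding context tokens equal hypothesis n's newest token, which is what the batched hunk computes at each frame.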