from local

This commit is contained in:
dohe0342 2023-02-25 16:22:53 +09:00
parent 296369343a
commit a31218ada2
2 changed files with 1 additions and 2 deletions

View File

@ -238,8 +238,7 @@ def greedy_search(
""" """
ctc_probs = probs ctc_probs = probs
_, max_index = ctc_probs.max(2) # (B, maxlen) _, max_index = ctc_probs.max(2) # (B, maxlen)
print(mask.size()) max_index = max_index.masked_fill_(mask[:,:max_index.size(1)], 0) # (B, maxlen)
max_index = max_index.masked_fill_(mask, 0) # (B, maxlen)
ret_hyps = [] ret_hyps = []
for hyp in max_index: for hyp in max_index: