Merge 4f316e98d0a0c2ed939f748102a45b7803c8379d into ccd8c624dd19c23b3ef576df3329092a78522e6f

This commit is contained in:
drawfish 2023-06-30 14:19:45 +09:00 committed by GitHub
commit 56a073a3b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -74,27 +74,31 @@ class FrameReducer(nn.Module):
padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
if y_lens is not None:
if y_lens is not None or self.training is False:
# Limit the maximum number of reduced frames
limit_lens = T - y_lens
if y_lens is not None:
limit_lens = T - y_lens
else:
# In eval mode, ensure audio that is completely silent does not make any errors
limit_lens = T - torch.ones_like(x_lens)
max_limit_len = limit_lens.max().int()
fake_limit_indexes = torch.topk(
ctc_output[:, :, blank_id], max_limit_len
).indices
T = (
_T = (
torch.arange(max_limit_len)
.expand_as(
fake_limit_indexes,
)
.to(device=x.device)
)
T = torch.remainder(T, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, T)
limit_mask = torch.full_like(
_T = torch.remainder(_T, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, _T)
limit_mask = (torch.full_like(
non_blank_mask,
False,
0,
device=x.device,
).scatter_(1, limit_indexes, True)
).scatter_(1, limit_indexes, 1) == 1)
non_blank_mask = non_blank_mask | ~limit_mask
@ -108,7 +112,7 @@ class FrameReducer(nn.Module):
)
- out_lens
)
max_pad_len = pad_lens_list.max()
max_pad_len = int(pad_lens_list.max().item())
out = F.pad(x, (0, 0, 0, max_pad_len))