Update frame_reducer.py to avoid warning on training mode.

This commit is contained in:
drawfish 2023-05-30 12:45:27 +08:00 committed by GitHub
parent d5ad908562
commit 4f316e98d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -74,7 +74,7 @@ class FrameReducer(nn.Module):
padding_mask = make_pad_mask(x_lens) padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
if y_lens is not None or self.training == False: if y_lens is not None or self.training is False:
# Limit the maximum number of reduced frames # Limit the maximum number of reduced frames
if y_lens is not None: if y_lens is not None:
limit_lens = T - y_lens limit_lens = T - y_lens
@ -85,15 +85,15 @@ class FrameReducer(nn.Module):
fake_limit_indexes = torch.topk( fake_limit_indexes = torch.topk(
ctc_output[:, :, blank_id], max_limit_len ctc_output[:, :, blank_id], max_limit_len
).indices ).indices
T = ( _T = (
torch.arange(max_limit_len) torch.arange(max_limit_len)
.expand_as( .expand_as(
fake_limit_indexes, fake_limit_indexes,
) )
.to(device=x.device) .to(device=x.device)
) )
T = torch.remainder(T, limit_lens.unsqueeze(1)) _T = torch.remainder(_T, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, torch.tensor(T)) limit_indexes = torch.gather(fake_limit_indexes, 1, _T)
limit_mask = (torch.full_like( limit_mask = (torch.full_like(
non_blank_mask, non_blank_mask,
0, 0,