Update frame_reducer.py to avoid warning on training mode.
This commit is contained in:
parent
d5ad908562
commit
4f316e98d0
@ -74,7 +74,7 @@ class FrameReducer(nn.Module):
|
||||
padding_mask = make_pad_mask(x_lens)
|
||||
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
||||
|
||||
if y_lens is not None or self.training == False:
|
||||
if y_lens is not None or self.training is False:
|
||||
# Limit the maximum number of reduced frames
|
||||
if y_lens is not None:
|
||||
limit_lens = T - y_lens
|
||||
@ -85,15 +85,15 @@ class FrameReducer(nn.Module):
|
||||
fake_limit_indexes = torch.topk(
|
||||
ctc_output[:, :, blank_id], max_limit_len
|
||||
).indices
|
||||
T = (
|
||||
_T = (
|
||||
torch.arange(max_limit_len)
|
||||
.expand_as(
|
||||
fake_limit_indexes,
|
||||
)
|
||||
.to(device=x.device)
|
||||
)
|
||||
T = torch.remainder(T, limit_lens.unsqueeze(1))
|
||||
limit_indexes = torch.gather(fake_limit_indexes, 1, torch.tensor(T))
|
||||
_T = torch.remainder(_T, limit_lens.unsqueeze(1))
|
||||
limit_indexes = torch.gather(fake_limit_indexes, 1, _T)
|
||||
limit_mask = (torch.full_like(
|
||||
non_blank_mask,
|
||||
0,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user