optimize frame_reducer.py (#783)

Co-authored-by: yifanyang <yifanyeung@yifanyangs-MacBook-Pro.local>
This commit is contained in:
Yifan Yang 2022-12-23 17:55:36 +08:00 committed by GitHub
parent 7eb2d0edb6
commit 59eb465b3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -66,19 +66,14 @@ class FrameReducer(nn.Module):
padding_mask = make_pad_mask(x_lens) padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask) non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
T_range = torch.arange(x.shape[1], device=x.device)
frames_list: List[torch.Tensor] = [] frames_list: List[torch.Tensor] = []
lens_list: List[int] = [] lens_list: List[int] = []
for i in range(x.shape[0]): for i in range(x.shape[0]):
indexes = torch.masked_select( frames = x[i][non_blank_mask[i]]
T_range,
non_blank_mask[i],
)
frames = x[i][indexes]
frames_list.append(frames) frames_list.append(frames)
lens_list.append(frames.shape[0]) lens_list.append(frames.shape[0])
x_fr = pad_sequence(frames_list).transpose(0, 1) x_fr = pad_sequence(frames_list, batch_first=True)
x_lens_fr = torch.tensor(lens_list).to(device=x.device) x_lens_fr = torch.tensor(lens_list).to(device=x.device)
return x_fr, x_lens_fr return x_fr, x_lens_fr