mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
optimize frame_reducer.py (#783)
Co-authored-by: yifanyang <yifanyeung@yifanyangs-MacBook-Pro.local>
This commit is contained in:
parent
7eb2d0edb6
commit
59eb465b3c
@ -66,19 +66,14 @@ class FrameReducer(nn.Module):
|
||||
|
||||
padding_mask = make_pad_mask(x_lens)
|
||||
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
||||
T_range = torch.arange(x.shape[1], device=x.device)
|
||||
|
||||
frames_list: List[torch.Tensor] = []
|
||||
lens_list: List[int] = []
|
||||
for i in range(x.shape[0]):
|
||||
indexes = torch.masked_select(
|
||||
T_range,
|
||||
non_blank_mask[i],
|
||||
)
|
||||
frames = x[i][indexes]
|
||||
frames = x[i][non_blank_mask[i]]
|
||||
frames_list.append(frames)
|
||||
lens_list.append(frames.shape[0])
|
||||
x_fr = pad_sequence(frames_list).transpose(0, 1)
|
||||
x_fr = pad_sequence(frames_list, batch_first=True)
|
||||
x_lens_fr = torch.tensor(lens_list).to(device=x.device)
|
||||
|
||||
return x_fr, x_lens_fr
|
||||
|
Loading…
x
Reference in New Issue
Block a user