mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
optimize frame_reducer.py (#783)
Co-authored-by: yifanyang <yifanyeung@yifanyangs-MacBook-Pro.local>
This commit is contained in:
parent
7eb2d0edb6
commit
59eb465b3c
@ -66,19 +66,14 @@ class FrameReducer(nn.Module):
|
|||||||
|
|
||||||
padding_mask = make_pad_mask(x_lens)
|
padding_mask = make_pad_mask(x_lens)
|
||||||
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
||||||
T_range = torch.arange(x.shape[1], device=x.device)
|
|
||||||
|
|
||||||
frames_list: List[torch.Tensor] = []
|
frames_list: List[torch.Tensor] = []
|
||||||
lens_list: List[int] = []
|
lens_list: List[int] = []
|
||||||
for i in range(x.shape[0]):
|
for i in range(x.shape[0]):
|
||||||
indexes = torch.masked_select(
|
frames = x[i][non_blank_mask[i]]
|
||||||
T_range,
|
|
||||||
non_blank_mask[i],
|
|
||||||
)
|
|
||||||
frames = x[i][indexes]
|
|
||||||
frames_list.append(frames)
|
frames_list.append(frames)
|
||||||
lens_list.append(frames.shape[0])
|
lens_list.append(frames.shape[0])
|
||||||
x_fr = pad_sequence(frames_list).transpose(0, 1)
|
x_fr = pad_sequence(frames_list, batch_first=True)
|
||||||
x_lens_fr = torch.tensor(lens_list).to(device=x.device)
|
x_lens_fr = torch.tensor(lens_list).to(device=x.device)
|
||||||
|
|
||||||
return x_fr, x_lens_fr
|
return x_fr, x_lens_fr
|
||||||
|
Loading…
x
Reference in New Issue
Block a user