mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add support for export.py
This commit is contained in:
parent
63c7402297
commit
d5ad908562
@ -74,7 +74,7 @@ class FrameReducer(nn.Module):
|
|||||||
padding_mask = make_pad_mask(x_lens)
|
padding_mask = make_pad_mask(x_lens)
|
||||||
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
|
||||||
|
|
||||||
if y_lens is not None or self.training is False:
|
if y_lens is not None or self.training == False:
|
||||||
# Limit the maximum number of reduced frames
|
# Limit the maximum number of reduced frames
|
||||||
if y_lens is not None:
|
if y_lens is not None:
|
||||||
limit_lens = T - y_lens
|
limit_lens = T - y_lens
|
||||||
@ -93,12 +93,12 @@ class FrameReducer(nn.Module):
|
|||||||
.to(device=x.device)
|
.to(device=x.device)
|
||||||
)
|
)
|
||||||
T = torch.remainder(T, limit_lens.unsqueeze(1))
|
T = torch.remainder(T, limit_lens.unsqueeze(1))
|
||||||
limit_indexes = torch.gather(fake_limit_indexes, 1, T)
|
limit_indexes = torch.gather(fake_limit_indexes, 1, torch.tensor(T))
|
||||||
limit_mask = torch.full_like(
|
limit_mask = (torch.full_like(
|
||||||
non_blank_mask,
|
non_blank_mask,
|
||||||
False,
|
0,
|
||||||
device=x.device,
|
device=x.device,
|
||||||
).scatter_(1, limit_indexes, True)
|
).scatter_(1, limit_indexes, 1) == 1)
|
||||||
|
|
||||||
non_blank_mask = non_blank_mask | ~limit_mask
|
non_blank_mask = non_blank_mask | ~limit_mask
|
||||||
|
|
||||||
@ -112,7 +112,7 @@ class FrameReducer(nn.Module):
|
|||||||
)
|
)
|
||||||
- out_lens
|
- out_lens
|
||||||
)
|
)
|
||||||
max_pad_len = pad_lens_list.max()
|
max_pad_len = int(pad_lens_list.max().item())
|
||||||
|
|
||||||
out = F.pad(x, (0, 0, 0, max_pad_len))
|
out = F.pad(x, (0, 0, 0, max_pad_len))
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user