Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-18 21:44:18 +00:00
Fix

This commit is contained in:
parent 438fef7215
commit ccdacc1b44
Submodule added:
egs/librispeech/ASR/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
@@ -0,0 +1 @@
+Subproject commit 9417bd9bc4aae7ada8b7943f5849828eecbf3c91
@@ -81,20 +81,20 @@ class FrameReducer(nn.Module):
             fake_limit_indexes = torch.topk(
                 ctc_output[:, :, blank_id], max_limit_len
             ).indices
-            T = (
+            T_arange = (
                 torch.arange(max_limit_len)
                 .expand_as(
                     fake_limit_indexes,
                 )
                 .to(device=x.device)
             )
-            T = torch.remainder(T, limit_lens.unsqueeze(1))
-            limit_indexes = torch.gather(fake_limit_indexes, 1, T)
+            T_arange = torch.remainder(T_arange, limit_lens.unsqueeze(1))
+            limit_indexes = torch.gather(fake_limit_indexes, 1, T_arange)
             limit_mask = torch.full_like(
                 non_blank_mask,
-                False,
+                0,
                 device=x.device,
-            ).scatter_(1, limit_indexes, True)
+            ).scatter_(1, limit_indexes, 1)
 
             non_blank_mask = non_blank_mask | ~limit_mask
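
Below is a minimal, self-contained sketch of the limit-mask logic this hunk touches, after the fix. The rename from T to T_arange avoids reusing T, which elsewhere in FrameReducer holds the time-dimension size, and the 0/1 fill and scatter values in place of False/True are presumably friendlier to export paths such as torch.jit/ONNX. The shapes and the limit_lens/non_blank_mask inputs here are illustrative assumptions, not values taken from the file.

import torch

# Illustrative shapes and inputs (assumptions, not from the file).
N, T, C, blank_id = 2, 10, 5, 0
ctc_output = torch.randn(N, T, C).log_softmax(dim=-1)   # (N, T, C) log-probs
non_blank_mask = ctc_output.argmax(dim=-1) != blank_id  # frames kept so far
limit_lens = torch.tensor([4, 6])  # max frames that may be dropped per row
max_limit_len = int(limit_lens.max())

# Indices of the max_limit_len most blank-like frames per utterance.
fake_limit_indexes = torch.topk(ctc_output[:, :, blank_id], max_limit_len).indices
T_arange = (
    torch.arange(max_limit_len)
    .expand_as(fake_limit_indexes)
    .to(device=ctc_output.device)
)
# Wrap the arange so row i only uses its first limit_lens[i] top-k indices.
T_arange = torch.remainder(T_arange, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, T_arange)
# 0/1 instead of False/True: the same boolean mask on a bool tensor.
limit_mask = torch.full_like(non_blank_mask, 0).scatter_(1, limit_indexes, 1)
# Drop only frames that are blank-predicted AND among the droppable top-k,
# so row i always keeps at least T - limit_lens[i] frames.
non_blank_mask = non_blank_mask | ~limit_mask
assert (non_blank_mask.sum(dim=1) >= T - limit_lens).all()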
@@ -108,9 +108,9 @@ class FrameReducer(nn.Module):
             )
             - out_lens
         )
-        max_pad_len = pad_lens_list.max()
+        max_pad_len = int(pad_lens_list.max())
 
-        out = F.pad(x, (0, 0, 0, max_pad_len))
+        out = F.pad(x, [0, 0, 0, max_pad_len])
 
         valid_pad_mask = ~make_pad_mask(pad_lens_list)
         total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
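
And a sketch of the padding step the second hunk adjusts: int() turns the 0-dim tensor from .max() into a Python int, and the pad spec becomes a list, presumably to satisfy torch.jit.script, which types F.pad's pad argument as List[int]. The make_pad_mask below is a minimal stand-in for icefall's helper of the same name; shapes are again illustrative assumptions.

import torch
import torch.nn.functional as F

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Minimal stand-in for icefall's make_pad_mask: True at padded positions.
    max_len = int(lengths.max())
    row = torch.arange(max_len, device=lengths.device)
    return row.expand(lengths.size(0), max_len) >= lengths.unsqueeze(1)

# Illustrative inputs (assumptions, not from the file).
N, T, C = 2, 10, 4
x = torch.randn(N, T, C)
non_blank_mask = torch.rand(N, T) > 0.5    # frames selected for keeping
out_lens = non_blank_mask.sum(dim=1)
pad_lens_list = out_lens.max() - out_lens  # right-padding needed per row
max_pad_len = int(pad_lens_list.max())     # Python int, not a 0-dim tensor

out = F.pad(x, [0, 0, 0, max_pad_len])     # pad the time dim on the right

# Mark exactly pad_lens_list[i] of the appended frames as valid in row i,
# so every row of total_valid_mask selects out_lens.max() frames and the
# reshape below yields equal-length rows.
valid_pad_mask = ~make_pad_mask(pad_lens_list)
total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
out = out[total_valid_mask].reshape(N, -1, C)
print(out.shape)  # (N, out_lens.max(), C)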