mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
check tensor dimension
This commit is contained in:
parent
d68532e0c9
commit
356056cea2
@ -103,6 +103,15 @@ class Transducer(nn.Module):
|
|||||||
y_padded = y.pad(mode="constant", padding_value=0)
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
|
||||||
y_padded = y_padded.to(torch.int64)
|
y_padded = y_padded.to(torch.int64)
|
||||||
|
|
||||||
|
max_sym_id = torch.max(y_padded)
|
||||||
|
assert encoder_out.size(0) == decoder_out.size(0),\
|
||||||
|
[encoder_out.size(), decoder_out.size()]
|
||||||
|
assert encoder_out.size(2) == decoder_out.size(2),\
|
||||||
|
[encoder_out.size(), decoder_out.size()]
|
||||||
|
assert encoder_out.size(2) >= (max_sym_id + 1),\
|
||||||
|
[encoder_out.size(), max_sym_id]
|
||||||
|
|
||||||
boundary = torch.zeros(
|
boundary = torch.zeros(
|
||||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user