check tensor dimension

This commit is contained in:
PingFeng Luo 2022-01-19 17:15:37 +08:00
parent d68532e0c9
commit 356056cea2

View File

@ -103,6 +103,15 @@ class Transducer(nn.Module):
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64) y_padded = y_padded.to(torch.int64)
max_sym_id = torch.max(y_padded)
assert encoder_out.size(0) == decoder_out.size(0),\
[encoder_out.size(), decoder_out.size()]
assert encoder_out.size(2) == decoder_out.size(2),\
[encoder_out.size(), decoder_out.size()]
assert encoder_out.size(2) >= (max_sym_id + 1),\
[encoder_out.size(), max_sym_id]
boundary = torch.zeros( boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device (x.size(0), 4), dtype=torch.int64, device=x.device
) )