diff --git a/egs/wenetspeech/ASR/transducer_stateless/model.py b/egs/wenetspeech/ASR/transducer_stateless/model.py index 3d562e1a8..016c58594 100644 --- a/egs/wenetspeech/ASR/transducer_stateless/model.py +++ b/egs/wenetspeech/ASR/transducer_stateless/model.py @@ -103,6 +103,15 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) + + max_sym_id = torch.max(y_padded) + assert encoder_out.size(0) == decoder_out.size(0),\ + [encoder_out.size(), decoder_out.size()] + assert encoder_out.size(2) == decoder_out.size(2),\ + [encoder_out.size(), decoder_out.size()] + assert encoder_out.size(2) >= (max_sym_id + 1),\ + [encoder_out.size(), max_sym_id] + boundary = torch.zeros( (x.size(0), 4), dtype=torch.int64, device=x.device )