Fix to joiner to allow different dims

This commit is contained in:
Daniel Povey 2022-04-04 13:34:43 +08:00
parent 9f62a0296c
commit 0fd0828f79

View File

@ -47,7 +47,7 @@ class Joiner(nn.Module):
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 4
assert encoder_out.shape == decoder_out.shape
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)