diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 4b6cab2d8..c7dc693ea 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1052,7 +1052,8 @@ class RelPositionalEncoding(torch.nn.Module): + x.size(0), : ] - pos_emb = pos_emb.unsqueeze(0) # now: (1, 2*time-1, embed_dim) + batch_size = x.size(1) + pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1) # now: (batch_size, 2*time-1, embed_dim) return self.dropout(pos_emb) @@ -1176,8 +1177,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module): seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) - # pos shape now: (head, 1, pos_dim, seq_len2) + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # [where seq_len2 represents relative position.]