From e67d4ca40d396996d3be1dd46b247c41215e0cf3 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 12 Nov 2022 19:21:29 +0800
Subject: [PATCH] Make pos_emb be dropped out independently across batch

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 4b6cab2d8..c7dc693ea 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1052,7 +1052,8 @@ class RelPositionalEncoding(torch.nn.Module):
             + x.size(0),
             :
         ]
-        pos_emb = pos_emb.unsqueeze(0)  # now: (1, 2*time-1, embed_dim)
+        batch_size = x.size(1)
+        pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)  # now: (batch_size, 2*time-1, embed_dim)
         return self.dropout(pos_emb)


@@ -1176,8 +1177,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

        seq_len2 = 2 * seq_len - 1

-        pos_emb = pos_emb.reshape(1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
-        # pos shape now: (head, 1, pos_dim, seq_len2)
+        pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
+        # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
         # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         # [where seq_len2 represents relative position.]
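
Before this commit, pos_emb left RelPositionalEncoding with a leading batch dimension of 1, so self.dropout drew one mask that was shared by every sequence in the batch; expanding to batch_size before the dropout call gives each batch element its own mask, and the reshape(-1, ...) in the second hunk lets the attention weights accept either the old leading dimension of 1 or the new batch_size. Below is a minimal standalone sketch of that effect in plain PyTorch, assuming made-up tensor sizes; it is not the zipformer code itself.

    # Sketch: dropout applied after expanding the batch dimension draws an
    # independent mask per batch element; with a leading dim of 1 the mask is shared.
    import torch

    torch.manual_seed(0)
    dropout = torch.nn.Dropout(p=0.5)
    dropout.train()

    seq_len2, embed_dim, batch_size = 5, 4, 3  # stand-ins; seq_len2 plays the role of 2*time-1
    pos_emb = torch.randn(seq_len2, embed_dim)

    # Old behaviour: shape (1, seq_len2, embed_dim), one dropout mask for the whole batch.
    shared = dropout(pos_emb.unsqueeze(0))

    # New behaviour: expand to (batch_size, seq_len2, embed_dim) first, so dropout
    # draws a fresh mask for every batch element (expand is a view; dropout writes
    # its result to a new tensor, so nothing is modified in place).
    independent = dropout(pos_emb.unsqueeze(0).expand(batch_size, -1, -1))

    print(torch.equal(independent[0], independent[1]))  # usually False: masks differ per element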