From ba69eb48fe94d58730bc11fdb2f0705bc8bc679a Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 14 Nov 2022 15:31:56 +0800
Subject: [PATCH] Remove pos_emb schedule

---
 .../pruned_transducer_stateless7/zipformer.py | 40 ++++++++-----------
 1 file changed, 17 insertions(+), 23 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 313fdf56e..eba5806a2 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1014,8 +1014,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
      query_head_dim: dimension of the query (and key), per head. e.g. 24.
        pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
             dropout: dropout probability for attn_output_weights. Default: 0.0.
-      pos_emb_skip: probability for skipping the pos_emb part of the scores on
-         any given call to forward(), in training time.
     """
 
     def __init__(
@@ -1026,8 +1024,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         query_head_dim: int,
         pos_head_dim: int,
         dropout: float = 0.0,
-        pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
-                                                 (4000.0, 0.025), default=0.0)
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -1035,7 +1031,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
         self.dropout = dropout
-        self.pos_emb_skip = copy.deepcopy(pos_emb_skip)
 
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@@ -1120,26 +1115,25 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         attn_scores = torch.matmul(q, k)
 
-        if not self.training or random.random() >= float(self.pos_emb_skip):
-            pos_emb = self.linear_pos(pos_emb)
-            seq_len2 = 2 * seq_len - 1
-            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
-            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+        pos_emb = self.linear_pos(pos_emb)
+        seq_len2 = 2 * seq_len - 1
+        pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
+        # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
 
-            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
-            # [where seq_len2 represents relative position.]
-            pos_scores = torch.matmul(p, pos_emb)
-            # the following .as_strided() expression converts the last axis of pos_scores from relative
-            # to absolute position. I don't know whether I might have got the time-offsets backwards or
-            # not, but let this code define which way round it is supposed to be.
-            pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
-                                               (pos_scores.stride(0),
-                                                pos_scores.stride(1),
-                                                pos_scores.stride(2)-pos_scores.stride(3),
-                                                pos_scores.stride(3)),
-                                               storage_offset=pos_scores.stride(3) * (seq_len - 1))
+        # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+        # [where seq_len2 represents relative position.]
+        pos_scores = torch.matmul(p, pos_emb)
+        # the following .as_strided() expression converts the last axis of pos_scores from relative
+        # to absolute position. I don't know whether I might have got the time-offsets backwards or
+        # not, but let this code define which way round it is supposed to be.
+        pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
+                                           (pos_scores.stride(0),
+                                            pos_scores.stride(1),
+                                            pos_scores.stride(2)-pos_scores.stride(3),
+                                            pos_scores.stride(3)),
+                                           storage_offset=pos_scores.stride(3) * (seq_len - 1))
 
-            attn_scores = attn_scores + pos_scores
+        attn_scores = attn_scores + pos_scores
 
         if self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
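Note on the removed schedule (not part of the patch): pos_emb_skip defaulted to ScheduledFloat((0.0, 0.5), (4000.0, 0.025), default=0.0). Assuming ScheduledFloat acts as a piecewise-linear schedule over the batch count, as in icefall's scaling.py, the deleted branch skipped the positional term with a probability that decayed from 0.5 at the start of training to 0.025 by batch 4000; after this patch the term is always added. A minimal standalone sketch of that schedule follows; the helper name pos_emb_skip_prob is hypothetical and only illustrates the interpolation.

def pos_emb_skip_prob(batch_count: float) -> float:
    # Piecewise-linear interpolation through (0.0, 0.5) and (4000.0, 0.025),
    # clamped at the endpoints; mirrors the removed ScheduledFloat default.
    (x0, y0), (x1, y1) = (0.0, 0.5), (4000.0, 0.025)
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    frac = (batch_count - x0) / (x1 - x0)
    return y0 + frac * (y1 - y0)

# The deleted code amounted to:
#     if not self.training or random.random() >= pos_emb_skip_prob(batch_count):
#         attn_scores = attn_scores + pos_scores
print(pos_emb_skip_prob(0.0), pos_emb_skip_prob(2000.0), pos_emb_skip_prob(4000.0))
# 0.5 0.2625 0.025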
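The comment about converting the last axis of pos_scores from relative to absolute position can be checked in isolation. The sketch below (not part of the patch; shapes and values are synthetic) applies the same .as_strided() expression to a random tensor and verifies that output[h, b, t, s] == pos_scores[h, b, t, (s - t) + (seq_len - 1)], i.e. the relative offset s - t between key position s and query position t is re-indexed to an absolute key position.

import torch

num_heads, batch_size, seq_len = 2, 3, 5
seq_len2 = 2 * seq_len - 1  # number of distinct relative offsets

# Last axis indexes relative position; index seq_len - 1 corresponds to offset 0.
pos_scores = torch.randn(num_heads, batch_size, seq_len, seq_len2)

strided = pos_scores.as_strided(
    (num_heads, batch_size, seq_len, seq_len),
    (pos_scores.stride(0),
     pos_scores.stride(1),
     pos_scores.stride(2) - pos_scores.stride(3),
     pos_scores.stride(3)),
    storage_offset=pos_scores.stride(3) * (seq_len - 1),
)

# Reference with explicit indexing over (query position t, key position s).
ref = torch.empty_like(strided)
for t in range(seq_len):
    for s in range(seq_len):
        ref[:, :, t, s] = pos_scores[:, :, t, (s - t) + (seq_len - 1)]

assert torch.equal(strided, ref)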