diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index ada0c2ca7..b97bd51ee 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -504,32 +504,37 @@ class ZipformerEncoderLayer(nn.Module):
             )  # else rely on the ones passed in
 
+        # use different heads for nonlin_attention_module and attention_squeeze, depending
+        # on whether this module has its own self_attn_weights submodule or is borrowing
+        # attention weights from another one.
+        head_offset = 0 if self.self_attn_weights is not None else 2
+
         if self.training and random.random() < float(self.layer_skip_rate):
             # skip the layer
             return src, attn_weights
 
         use_self_attn = (random.random() >= dynamic_skip_rate)
 
         if use_self_attn:
-            first_attn_weights = attn_weights[0:3]
+            selected_attn_weights = attn_weights[head_offset:head_offset+2]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant.  The intention is to
                 # encourage these modules to do something similar to an
                 # averaging-over-time operation.
                 # only need the mask, can just use the 1st one and expand later
-                first_attn_weights = first_attn_weights[0:1]
-                first_attn_weights = (first_attn_weights > 0.0).to(first_attn_weights.dtype)
-                first_attn_weights = first_attn_weights * (1.0 / first_attn_weights.sum(dim=-1, keepdim=True))
-                first_attn_weights = first_attn_weights.expand(3, -1, -1, -1)
+                selected_attn_weights = selected_attn_weights[0:1]
+                selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
+                selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
+                selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
 
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.nonlin_attention_module(src,
-                                                     first_attn_weights[0:1])
+                                                     selected_attn_weights[0:1])
 
         src = src + self.feed_forward1(src)
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze(src, first_attn_weights[1:2])
+            src = src + self.attention_squeeze(src, selected_attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
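For reference, a minimal standalone sketch (not part of the patch) of the head selection and constant-attention branch added above. It assumes attn_weights is stacked along a leading head dimension with shape (num_heads, batch, tgt_len, src_len); has_own_weights is a hypothetical stand-in for the real check self.self_attn_weights is not None.

import torch

# Dummy attention weights: (num_heads, batch, tgt_len, src_len) is an assumed layout.
num_heads, batch, seq = 4, 2, 5
attn_weights = torch.softmax(torch.randn(num_heads, batch, seq, seq), dim=-1)

# A layer that owns its attention weights uses heads [0, 2); a layer that borrows
# them from another layer uses heads [2, 4).  `has_own_weights` is hypothetical,
# standing in for `self.self_attn_weights is not None` in the patch.
has_own_weights = False
head_offset = 0 if has_own_weights else 2
selected = attn_weights[head_offset:head_offset + 2]

# Constant-attention branch: keep only the mask (positions with nonzero weight)
# from the first selected head, replace the weights with a uniform average over
# those positions, then expand back to two heads.
const = selected[0:1]
const = (const > 0.0).to(const.dtype)
const = const * (1.0 / const.sum(dim=-1, keepdim=True))
const = const.expand(2, -1, -1, -1)

# Each row still sums to 1, but attention is now a plain average over time.
assert torch.allclose(const.sum(dim=-1), torch.ones(2, batch, seq))

In the layer itself, selected_attn_weights[0:1] then feeds nonlin_attention_module and selected_attn_weights[1:2] feeds attention_squeeze, as in the hunk above.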