From 4eb3e9784847e125468d8f017aea476f68e24ad7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Fri, 16 Dec 2022 17:59:15 +0800 Subject: [PATCH] Remove bias from SimpleUpsample, add one to AttentionDownsample --- .../ASR/pruned_transducer_stateless7/zipformer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 42dfa642f..692f8ffbc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -799,6 +799,7 @@ class AttentionDownsample(torch.nn.Module): """ super(AttentionDownsample, self).__init__() self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) + self.bias = nn.Parameter(torch.zeros(downsample)) self.name = None # will be set from training code self.dropout = copy.deepcopy(dropout) @@ -832,8 +833,9 @@ class AttentionDownsample(torch.nn.Module): assert src.shape[0] == d_seq_len * ds src = src.reshape(d_seq_len, ds, batch_size, in_channels) - # scores: (d_seq_len, downsample, batch_size) + # scores: (d_seq_len, downsample, batch_size, 1) scores = (src * self.query).sum(dim=-1, keepdim=True) + scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1) scores = penalize_abs_values_gt(scores, limit=20.0, @@ -869,7 +871,7 @@ class SimpleUpsample(torch.nn.Module): num_channels: int, upsample: int): super(SimpleUpsample, self).__init__() - self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) + self.upsample = upsample def forward(self, src: Tensor) -> Tensor: @@ -878,10 +880,9 @@ class SimpleUpsample(torch.nn.Module): Returns a tensor of shape ( (seq_len*upsample), batch_size, num_channels) """ - upsample = self.bias.shape[0] + upsample = self.upsample (seq_len, batch_size, num_channels) = src.shape src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src + self.bias.unsqueeze(1) src = src.reshape(seq_len * upsample, batch_size, num_channels) return src