From cb12014c31fcaa01947a3cd5f79aa779f6df9cf3 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 10 Dec 2022 16:09:51 +0800
Subject: [PATCH] Implement dropout for scores in AttentionDownsample

---
 .../pruned_transducer_stateless7/zipformer.py | 23 +++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e8de57269..29ce8840d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -202,6 +202,7 @@ class Zipformer(EncoderInterface):
                     input_dim=encoder_dim[i-1] if i > 0 else encoder_dim[0],
                     output_dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
+                    dropout=dropout,
                 )
             encoders.append(encoder)
         self.encoders = nn.ModuleList(encoders)
@@ -211,7 +212,8 @@ class Zipformer(EncoderInterface):
 
         self.downsample_output = AttentionDownsample(encoder_dim[-1],
                                                      encoder_dim[-1],
-                                                     downsample=output_downsampling_factor)
+                                                     downsample=output_downsampling_factor,
+                                                     dropout=dropout)
 
 
     def _init_skip_modules(self):
@@ -677,10 +679,12 @@ class DownsampledZipformerEncoder(nn.Module):
                  encoder: nn.Module,
                  input_dim: int,
                  output_dim: int,
-                 downsample: int):
+                 downsample: int,
+                 dropout: FloatLike):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
+        self.downsample = AttentionDownsample(input_dim, output_dim,
+                                              downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(output_dim, downsample)
         self.out_combiner = SimpleCombiner(input_dim,
@@ -794,7 +798,8 @@ class AttentionDownsample(torch.nn.Module):
     def __init__(self,
                  in_channels: int,
                  out_channels: int,
-                 downsample: int):
+                 downsample: int,
+                 dropout: FloatLike):
         """
         Require out_channels > in_channels.
         """
@@ -802,6 +807,7 @@ class AttentionDownsample(torch.nn.Module):
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
 
         self.name = None # will be set from training code
+        self.dropout = copy.deepcopy(dropout)
 
         # fill in the extra dimensions with a projection of the input
         if out_channels > in_channels:
@@ -832,6 +838,7 @@ class AttentionDownsample(torch.nn.Module):
             assert src.shape[0] == d_seq_len * ds
 
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+        # scores: (d_seq_len, downsample, batch_size)
        scores = (src * self.query).sum(dim=-1, keepdim=True)
         scores = penalize_abs_values_gt(scores,
                                         limit=10.0,
@@ -838,6 +845,14 @@ class AttentionDownsample(torch.nn.Module):
                                         penalty=1.0e-04,
                                         name=self.name)
 
+        dropout = float(self.dropout)
+        if dropout > 0.0:
+            # the 0:1, done on the axis of size 'downsample', selects just
+            # one dimension while keeping the dim. We'll then broadcast when
+            # we multiply.
+            dropout_mask = torch.rand_like(scores[:, 0:1]) > dropout
+            scores = scores * dropout_mask
+
         weights = scores.softmax(dim=1)
 
         # ans1 is the first `in_channels` channels of the output
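
Note for reviewers, on how the masking works: the mask is drawn once per (downsampled frame, batch element) rather than once per score, and because the `0:1` slice keeps the `downsample` axis with size 1, it broadcasts over all scores in a group. A "dropped" group therefore has all of its scores zeroed, and its softmax weights become uniform: the module falls back to plain mean pooling for that group instead of attention pooling. Below is a minimal standalone sketch of just this mechanic, not the icefall module itself: `ToyAttentionDownsample` is a hypothetical name, the right-padding, the extra-channel projection, penalize_abs_values_gt() and the FloatLike dropout schedule are all simplified away, and the `self.training` gate is an assumption of the sketch (the patch applies the mask whenever float(self.dropout) > 0).

    import torch
    import torch.nn as nn


    class ToyAttentionDownsample(nn.Module):
        """Downsamples (seq_len, batch, channels) by a factor `downsample`,
        attention-weighting each group of `downsample` frames with a
        learned query."""

        def __init__(self, channels: int, downsample: int, dropout: float = 0.1):
            super().__init__()
            self.downsample = downsample
            self.dropout = dropout  # a plain float here; a FloatLike schedule in icefall
            self.query = nn.Parameter(torch.randn(channels) * (channels ** -0.5))

        def forward(self, src: torch.Tensor) -> torch.Tensor:
            seq_len, batch_size, channels = src.shape
            ds = self.downsample
            # The sketch assumes an exact multiple; the real module right-pads.
            assert seq_len % ds == 0
            d_seq_len = seq_len // ds

            src = src.reshape(d_seq_len, ds, batch_size, channels)
            # scores: (d_seq_len, downsample, batch_size, 1)
            scores = (src * self.query).sum(dim=-1, keepdim=True)

            if self.training and self.dropout > 0.0:
                # One Bernoulli draw per (frame, batch element); the 0:1 slice
                # keeps the downsample axis with size 1, so the mask broadcasts
                # over the whole group when we multiply.
                dropout_mask = torch.rand_like(scores[:, 0:1]) > self.dropout
                scores = scores * dropout_mask

            weights = scores.softmax(dim=1)     # uniform within a dropped group
            return (src * weights).sum(dim=1)   # (d_seq_len, batch, channels)

Usage: with downsample=2, a (6, 3, 4) input maps to a (3, 3, 4) output; in train mode roughly a `dropout` fraction of the (frame, batch) groups are mean-pooled rather than attention-pooled:

    m = ToyAttentionDownsample(channels=4, downsample=2, dropout=0.1)
    y = m(torch.randn(6, 3, 4))   # y.shape == torch.Size([3, 3, 4])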