diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 1855e06ae..312ffe0e6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -108,14 +108,12 @@ class Conformer(EncoderInterface):
                 dropout=dropout
             ),
             input_dim=d_model[0],
-            module_dim=d_model[1],
             output_dim=d_model[1],
             downsample=conformer_subsampling_factor,
         )
 
-        self.out_proj = ScaledLinear(
-            d_model[0] + d_model[1], d_model[2],
-            bias=False)
+        self.out_combiner = SimpleCombiner(d_model[0],
+                                           d_model[1])
 
     def forward(
@@ -158,12 +156,10 @@ class Conformer(EncoderInterface):
             x1_no_combine, src_key_padding_mask=mask, warmup=warmup
         )  # (T, N, C) where C == d_model[1]
 
-        x = torch.cat((x1, x2), dim=2)
+        x = self.out_combiner(x1, x2)
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
-        x = self.out_proj(x)
-
         return x, lengths
@@ -415,29 +411,14 @@ class DownsampledConformerEncoder(nn.Module):
     def __init__(self,
                  encoder: nn.Module,
                  input_dim: int,
-                 module_dim: int,
                  output_dim: int,
                  downsample: int):
         super(DownsampledConformerEncoder, self).__init__()
-
         self.downsample_factor = downsample
-
-        # note: we'll pad manually.
-        self.downsample = nn.Conv1d(
-            input_dim,
-            module_dim,
-            kernel_size=downsample,
-            stride=downsample,
-            padding=0)
-
+        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
         self.encoder = encoder
+        self.upsample = SimpleUpsample(output_dim, downsample)
-        self.upsample = nn.ConvTranspose1d(
-            module_dim,
-            output_dim,
-            kernel_size=downsample,
-            stride=downsample,
-            padding=0)
 
     def forward(self,
                 src: Tensor,
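Aside, not part of the patch: the hunks below replace the strided Conv1d with `AttentionDownsample`, which groups frames into blocks of `downsample`, scores each frame against a learned query vector, and pools each block by softmax-weighted sum; a bias-free linear projection of the concatenated block fills the remaining `out_channels - in_channels` dimensions. A minimal standalone sketch of the weighted-sum pooling, with toy sizes (all names here are illustrative, not from the patch):

import torch

seq_len, batch, in_ch, ds = 12, 2, 4, 3          # toy sizes; divisible case, so no padding needed
src = torch.randn(seq_len, batch, in_ch)
query = torch.randn(in_ch) * (in_ch ** -0.5)     # same init scale as the patch's self.query

d_seq_len = (seq_len + ds - 1) // ds
blocks = src.reshape(d_seq_len, ds, batch, in_ch)
scores = (blocks * query).sum(dim=-1, keepdim=True)   # (d_seq_len, ds, batch, 1)
weights = scores.softmax(dim=1)                       # normalize within each block of ds frames
pooled = (blocks * weights).sum(dim=1)                # (d_seq_len, batch, in_ch)
print(pooled.shape)                                   # torch.Size([4, 2, 4])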
@@ -462,10 +443,56 @@ class DownsampledConformerEncoder(nn.Module):
         Returns: output of shape (S, N, F) where F is the number of output features
             (output_dim to constructor)
         """
-        (seq_len, batch_size, embedding_dim) = src.shape
+
+        src_orig = src
+        src = self.downsample(src)
         ds = self.downsample_factor
+        if mask is not None:
+            mask = mask[::ds,::ds]
+        if src_key_padding_mask is not None:
+            src_key_padding_mask = src_key_padding_mask[::ds]
+
+        src, _src_no_combine = self.encoder(
+            src, src_key_padding_mask=mask, warmup=warmup
+        )
+        src = self.upsample(src)
+
+        return src
+
+
+class AttentionDownsample(torch.nn.Module):
+    """
+    Does downsampling with attention, by weighted sum, and a projection.
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 downsample: int):
+        """
+        Require out_channels > in_channels.
+        """
+        super(AttentionDownsample, self).__init__()
+        assert out_channels > in_channels
+        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+
+        # fill in the extra dimensions with a projection of the input
+        self.extra_proj = nn.Linear(in_channels * downsample,
+                                    out_channels - in_channels,
+                                    bias=False)
+        self.downsample = downsample
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, in_channels)
+        Returns a tensor of shape
+           ((seq_len+downsample-1)//downsample, batch_size, out_channels)
+        """
+        (seq_len, batch_size, in_channels) = src.shape
+        ds = self.downsample
         d_seq_len = (seq_len + ds - 1) // ds
         src_orig = src
+        # Pad to an exact multiple of self.downsample
         if seq_len != d_seq_len * ds:
             # right-pad src, repeating the last element.
             pad = d_seq_len * ds - seq_len
@@ -473,31 +500,83 @@ class DownsampledConformerEncoder(nn.Module):
             src = torch.cat((src, src_extra), dim=0)
         assert src.shape[0] == d_seq_len * ds
 
-        if mask is not None:
-            mask = mask[::ds,::ds]
-        if src_key_padding_mask is not None:
-            src_key_padding_mask = src_key_padding_mask[::ds]
+        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+        scores = (src * self.query).sum(dim=-1, keepdim=True)
+        weights = scores.softmax(dim=1)
 
-        src = src.permute(1, 2, 0)  # (#batch, channels, time).
-        src = self.downsample(src)
-        src = src.permute(2, 0, 1)  # (time, batch, channels)
+        # ans1 is the first `in_channels` channels of the output
+        ans1 = (src * weights).sum(dim=1)
+        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
+        ans2 = self.extra_proj(src)
 
-        src, _src_no_combine = self.encoder(
-            src, src_key_padding_mask=mask, warmup=warmup
-        )  # (T, N, C)
+        ans = torch.cat((ans1, ans2), dim=2)
+        return ans
 
-        src = src.permute(1, 2, 0)  # (#batch, channels, time).
-        src = self.upsample(src)
-        src = src.permute(2, 0, 1)  # (time, batch, channels)
-        new_seq_len = src.shape[0]
-        assert new_seq_len >= seq_len
-        if new_seq_len > seq_len:
-            src = src[:seq_len]
 
+class SimpleUpsample(torch.nn.Module):
+    """
+    A very simple form of upsampling that mostly just repeats the input, but
+    also adds a position-specific bias.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 upsample: int):
+        super(SimpleUpsample, self).__init__()
+        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, num_channels)
+        Returns a tensor of shape
+           ((seq_len*upsample), batch_size, num_channels)
+        """
+        upsample = self.bias.shape[0]
+        (seq_len, batch_size, num_channels) = src.shape
+        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+        src = src + self.bias.unsqueeze(1)
+        src = src.reshape(seq_len * upsample, batch_size, num_channels)
         return src
 
 
+class SimpleCombiner(torch.nn.Module):
+    """
+    A very simple way of combining 2 vectors of 2 different dims, via a
+    learned weighted combination in the shared part of the dim.
+    """
+    def __init__(self,
+                 dim1: int,
+                 dim2: int):
+        super(SimpleCombiner, self).__init__()
+        assert dim2 > dim1
+        self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
+        self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
+
+    def forward(self,
+                src1: Tensor,
+                src2: Tensor) -> Tensor:
+        """
+        src1: (*, dim1)
+        src2: (*, dim2)
+
+        Returns: a tensor of shape (*, dim2)
+        """
+        assert src1.shape[:-1] == src2.shape[:-1]
+        dim1 = src1.shape[-1]
+        dim2 = src2.shape[-1]
+
+        weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
+        weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
+        weight = (weight1 + weight2).sigmoid()
+
+        src2_part1 = src2[...,:dim1]
+        part1 = src1 * weight + src2_part1 * (1.0 - weight)
+        part2 = src2[...,dim1:]
+        return torch.cat((part1, part2), dim=-1)
+
+
 class RelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding module.
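Aside, not part of the patch: `SimpleCombiner` above derives one scalar gate per position from both inputs via the learned vectors `to_weight1`/`to_weight2`, interpolates the first `dim1` channels of `src2` with `src1` under that gate, and passes the remaining `dim2 - dim1` channels through unchanged. A standalone sketch of those semantics with assumed toy dimensions (names chosen for illustration):

import torch

dim1, dim2, T, N = 3, 5, 7, 2                  # toy dims; dim2 > dim1 as the patch asserts
src1 = torch.randn(T, N, dim1)
src2 = torch.randn(T, N, dim2)
to_weight1 = torch.randn(dim1) * 0.01          # same init scale as the patch
to_weight2 = torch.randn(dim2) * 0.01

w = ((src1 * to_weight1).sum(-1, keepdim=True)
     + (src2 * to_weight2).sum(-1, keepdim=True)).sigmoid()   # (T, N, 1) gate
part1 = src1 * w + src2[..., :dim1] * (1.0 - w)   # gated interpolation of shared channels
part2 = src2[..., dim1:]                          # extra channels pass through
out = torch.cat((part1, part2), dim=-1)
print(out.shape)                                  # torch.Size([7, 2, 5])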