diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index f81062256..c29f71d1a 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -93,12 +93,12 @@ class Subformer2(EncoderInterface):
         num_encoder_layers: Union[int, Tuple[int]] = 4,
         encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
         query_head_dim: Union[int, Tuple[int]] = 24,
-        pos_head_dim: Union[int, Tuple[int]] = 4,
         value_head_dim: Union[int, Tuple[int]] = 12,
         num_heads: Union[int, Tuple[int]] = 8,
         feedforward_dim: Union[int, Tuple[int]] = 1536,
         memory_dim: int = -1,
-        pos_dim: int = 192,
+        pos_emb_dim: int = 192,
+        pos_dim: int = 4,
         dropout: FloatLike = None,  # see code below for default
         warmup_batches: float = 4000.0,
         causal: bool = False,
@@ -127,7 +127,6 @@ class Subformer2(EncoderInterface):
         num_encoder_layers = _to_tuple(num_encoder_layers)
         query_head_dim = _to_tuple(query_head_dim)
         value_head_dim = _to_tuple(value_head_dim)
-        pos_head_dim = _to_tuple(pos_head_dim)
         num_heads = _to_tuple(num_heads)
         feedforward_dim = _to_tuple(feedforward_dim)
@@ -145,7 +144,6 @@ class Subformer2(EncoderInterface):
                 pos_dim=pos_dim,
                 num_heads=num_heads[i],
                 query_head_dim=query_head_dim[i],
-                pos_head_dim=pos_head_dim[i],
                 value_head_dim=value_head_dim[i],
                 feedforward_dim=feedforward_dim[i],
                 memory_dim=memory_dim,
@@ -175,6 +173,9 @@ class Subformer2(EncoderInterface):

             encoders.append(encoder)

+        self.encoder_pos = CompactRelPositionalEncoding(pos_emb_dim, pos_dim=pos_dim,
+                                                        dropout_rate=0.15,
+                                                        length_factor=1.0)
+
         self.encoders = nn.ModuleList(encoders)

         self.downsample_output = SimpleDownsample(max(encoder_dim),
@@ -272,7 +273,7 @@ class Subformer2(EncoderInterface):
         outputs = []
         feature_masks = self.get_feature_masks(x)

-        attn_mask = self._get_attn_mask(x)
+        attn_offset = self._get_attn_offset(x)

         if self.training and memory is not None:
             batch_size = x.shape[1]
@@ -282,16 +283,19 @@ class Subformer2(EncoderInterface):
             memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
                                memory_dropout_rate)

+        pos_emb = self.encoder_pos(x)
+
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
             x = convert_num_channels(x, self.encoder_dim[i])

             x = module(x,
+                       pos_emb,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=(None if src_key_padding_mask is None
                                              else src_key_padding_mask[...,::ds]),
-                       attn_mask=attn_mask,
+                       attn_offset=attn_offset,
                        memory=memory,
                        memory_key_padding_mask=memory_key_padding_mask,
                        )
@@ -324,11 +328,11 @@ class Subformer2(EncoderInterface):

         return x, lengths

-    def _get_attn_mask(self, x: Tensor) -> Optional[Tensor]:
+    def _get_attn_offset(self, x: Tensor) -> Optional[Tensor]:
         """
-        Return None if not self.causal is false else return attention mask of shape
-        (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
-        means a masked position.
+        Return an attention offset of shape (1, seq_len, seq_len), interpreted as
+        (tgt_seq_len, src_seq_len).  If causal == True it contains -inf in masked
+        (future) positions; otherwise it is all zeros.
+
           Args:
            x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
           chunk_size: chunk size, must divide
@@ -345,7 +349,11 @@ class Subformer2(EncoderInterface):

         attn_mask = (src_c > tgt_c)

-        return attn_mask
+        ans = torch.zeros(1, seq_len, seq_len, device=x.device)
+
+        ans.masked_fill_(attn_mask, float('-inf'))
+
+        return ans
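# ---------------------------------------------------------------------------
# Note (illustration only, not part of the patch): the boolean attn_mask is
# being replaced by an additive "attention offset" that is simply added to
# the attention logits; masked positions get -inf.  A minimal, self-contained
# sketch of the equivalence (names here are made up for the example):
import torch

seq_len = 6
src_c = torch.arange(seq_len)
tgt_c = src_c.unsqueeze(-1)
attn_mask = (src_c > tgt_c)                    # True = masked (future) position

attn_offset = torch.zeros(1, seq_len, seq_len)
attn_offset.masked_fill_(attn_mask, float('-inf'))

scores = torch.randn(1, seq_len, seq_len)      # stand-in for attention logits
weights = (scores + attn_offset).softmax(dim=-1)
assert torch.all(weights.triu(diagonal=1) == 0)   # no attention to the future
# ---------------------------------------------------------------------------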
@@ -379,7 +387,6 @@ class Subformer2EncoderLayer(nn.Module):
             pos_dim: int,
             num_heads: int,
             query_head_dim: int,
-            pos_head_dim: int,
             value_head_dim: int,
             feedforward_dim: int,
             dropout: FloatLike = 0.1,
@@ -416,8 +424,8 @@ class Subformer2EncoderLayer(nn.Module):
         self.const_attention_rate = copy.deepcopy(const_attention_rate)

         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
-            embed_dim, pos_dim=pos_dim, num_heads=num_heads,
-            query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
+            embed_dim, num_heads=num_heads,
+            query_head_dim=query_head_dim, pos_dim=pos_dim,
             dropout=0.0,
         )
@@ -552,7 +560,7 @@ class Subformer2EncoderLayer(nn.Module):
         src: Tensor,
         pos_emb: Tensor,
         chunk_size: int = -1,
-        attn_mask: Optional[Tensor] = None,
+        attn_offset: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
@@ -565,10 +573,9 @@ class Subformer2EncoderLayer(nn.Module):
          chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
          feature_mask: something that broadcasts with src, that we'll multiply `src`
              by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-         attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-              interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-              True means masked position.  May be None.
-         src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+         attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
+              interpreted as (batch_size, tgt_seq_len, src_seq_len).  -inf means a masked position.
+         src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
              masked position.  May be None.

         Returns:
@@ -583,7 +590,7 @@ class Subformer2EncoderLayer(nn.Module):
             attn_weights = self.self_attn_weights(
                 src,
                 pos_emb=pos_emb,
-                attn_mask=attn_mask,
+                attn_offset=attn_offset,
                 key_padding_mask=src_key_padding_mask,
             )
@@ -675,9 +682,6 @@ class Subformer2Encoder(nn.Module):
             final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
-        self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
-                                                        length_factor=1.0)
-
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
@@ -699,7 +703,8 @@ class Subformer2Encoder(nn.Module):
         src: Tensor,
+        pos_emb: Tensor,
         chunk_size: int = -1,
         feature_mask: Union[Tensor, float] = 1.0,
-        attn_mask: Optional[Tensor] = None,
+        attn_offset: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
@@ -711,9 +715,10 @@ class Subformer2Encoder(nn.Module):
             chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
             feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-                 True means masked position.  May be None.
+            pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
+            attn_offset: the attention offset (used for masking and for the downsampling
+                 weights), of shape broadcasting with (batch_size, seq_len, seq_len),
+                 interpreted as (batch_size, tgt_seq_len, src_seq_len); added to the
+                 attention logits, with -inf meaning a masked position.
             src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                  masked position.  May be None.
             memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
@@ -735,7 +739,7 @@ class Subformer2Encoder(nn.Module):
                 output,
                 pos_emb,
                 chunk_size=chunk_size,
-                attn_mask=attn_mask,
+                attn_offset=attn_offset,
                 src_key_padding_mask=src_key_padding_mask,
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
@@ -803,10 +807,227 @@ class BypassModule(nn.Module):
         return src_orig + (src - src_orig) * bypass_scale


+class LearnedDownsamplingModule(nn.Module):
+    """
+    Module that allows you to choose which frames to keep for transformer-type
+    modules.  Effectively downsampling, but not necessarily "evenly"- you just
+    keep some proportion of frames determined by the embedding.
+
+    Args:
+      embed_dim: embedding dimension
+      downsampling_factor: factor to downsample by, e.g. 2 or 4.  There is no
+         fundamental reason why this has to be an integer, but we make it so
+         anyway.
+      intermediate_rate: the proportion of the downsampled values that have
+          "intermediate weights"- between kept and downsampled.  The user is
+          supposed to use these in such a way that if the weight we return is
+          0.0, it's equivalent to not using this frame at all.
+    """
+    def __init__(self,
+                 embed_dim: int,
+                 downsampling_factor: int,
+                 intermediate_rate: FloatLike = 0.2):
+        super().__init__()
+        self.to_scores = nn.Linear(embed_dim, 1, bias=False)
+        # score_balancer is just to keep the magnitudes of the scores in
+        # a fixed range and keep them balanced around zero, to stop
+        # these drifting around.
+        self.score_balancer = Balancer(1, channel_dim=-1,
+                                       min_positive=0.4, max_positive=0.6,
+                                       min_abs=1.0, max_abs=1.2,
+                                       prob=0.025)
+
+        self.downsampling_factor = downsampling_factor
+        self.intermediate_rate = copy.deepcopy(intermediate_rate)
+
+
+    def forward(self,
+                x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        Args:
+           x: a Tensor of shape (seq_len, batch_size, embed_dim)
+
+        Returns: (frame_indexes, weights, x_downsampled)
+
+        frame_indexes: a Tensor of integer type, of shape (batch_size, reduced_seq_len)
+            where reduced_seq_len = (seq_len + d - 1) // d.  It contains elements
+            0 <= frame_indexes < seq_len, in sorted (increasing) order.
+
+        weights: a Tensor of shape (batch_size, reduced_seq_len),
+             corresponding to the kept frames; these will be between 0 and 1, but
+             mostly exactly 1.
+
+        x_downsampled: the kept frames of x, of shape
+             (reduced_seq_len, batch_size, embed_dim).
+        """
+        (seq_len, batch_size, _) = x.shape
+        scores = self.to_scores(x)  # (seq_len, batch_size, 1)
+        scores = self.score_balancer(scores)
+
+        scores = scores.squeeze(-1).t()  # (batch_size, seq_len)
+
+        # sscores, indexes: (batch_size, seq_len)
+        sscores, indexes = scores.sort(dim=-1, descending=True)
+
+        d = self.downsampling_factor
+        seq_len_reduced = (seq_len + d - 1) // d
+
+        # TODO: if seq_len / downsampling_factor <= 2, do something special.
+
+        # 'right' is the rightmost of the 2 limits; we want the scores indexed
+        # 'right' to be mapped to around 0.0
+        right = seq_len_reduced
+        # we want scores around 'left' to be mapped to around 1.0.
+        left = int(seq_len_reduced * (1.0 - float(self.intermediate_rate)))
+
+        # 'collar' determines the range of positions in the sorted list that we use to
+        # compute the average.  We could let collar be 0.0, which would more exactly
+        # accomplish what we want; but we don't, because this would cause too-noisy
+        # gradients, with too much gradient going to one frame.
+        collar = max(1, int(seq_len_reduced * 0.5 * float(self.intermediate_rate)))
+
+        # right_avg: shape (batch_size, 1), this is to be mapped to 0.0
+        right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
+
+        # left_avg: shape (batch_size, 1), this is to be mapped to 1.0
+        left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
+
+        # the + 0.001 is to avoid possible division by zero in case of ties.
+        weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
+        weights = weights.clamp(min=0.0, max=1.0)
+
+        indexes = indexes[:, :seq_len_reduced]
+        weights = weights[:, :seq_len_reduced]
+
+        # re-sort the indexes we kept, on index value, so that
+        # masking for causal models will be in the correct order.
+        indexes, reorder = indexes.sort(dim=-1)
+        weights = torch.gather(weights, dim=-1, index=reorder)
+
+        x_downsampled = self.downsample(x, indexes)
+        return indexes, weights, x_downsampled
+
+
+    @staticmethod
+    def downsample(x: Tensor, indexes: Tensor) -> Tensor:
+        """
+        Downsamples x via indexing with the indexes obtained from the
+        forward() function.
+
+        Args:
+            x: tensor of shape (seq_len, batch_size, num_channels)
+            indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements
+                 0 <= indexes < seq_len.
+        Returns:
+            x_downsampled, of shape (seq_len_reduced, batch_size, num_channels)
+        """
+        indexes_expanded = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
+        # indexes_expanded: (seq_len_reduced, batch_size, num_channels)
+        ans = torch.gather(x, dim=0, index=indexes_expanded)

+        if __name__ == '__main__':
+            # temp, for testing
+            x_reconstructed = LearnedDownsamplingModule.upsample(x, ans, indexes)
+            assert torch.allclose(x, x_reconstructed)
+
+        return ans
+
+
+    @staticmethod
+    def downsample_pos_emb(pos_emb: Tensor, indexes: Tensor) -> Tensor:
+        """
+        Downsample positional embedding tensor with the provided indexes.
+        Args:
+            pos_emb: (batch_size, seq_len, seq_len, pos_dim),
+                interpreted as (batch_size, tgt_seq_len, src_seq_len, pos_dim).
+            indexes: (batch_size, seq_len_reduced), containing integer elements
+                0 <= indexes < seq_len.
+        Returns:
+            pos_emb_downsampled, of shape (batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
+        """
+        (batch_size, seq_len_reduced) = indexes.shape
+        (_, _, seq_len, pos_dim) = pos_emb.shape
+
+        tgt_indexes = indexes.reshape(batch_size, seq_len_reduced, 1, 1).expand(
+            batch_size, seq_len_reduced, seq_len, pos_dim)
+
+        pos_emb = torch.gather(pos_emb, dim=1, index=tgt_indexes)
+        # now pos_emb: (batch_size, seq_len_reduced, seq_len, pos_dim)
+
+        src_indexes = indexes.reshape(batch_size, 1, seq_len_reduced, 1).expand(
+            batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
+
+        pos_emb = torch.gather(pos_emb, dim=2, index=src_indexes)
+        # now pos_emb: (batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
+        return pos_emb
+
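# ---------------------------------------------------------------------------
# Note (illustration only, not part of the patch): downsample_pos_emb() keeps
# the same frames along both the target and source axes of the
# (batch, tgt, src, pos_dim) grid.  A sketch of the same selection written
# with advanced indexing (all names below are made up for the example):
import torch

def downsample_pos_emb_ref(pos_emb: torch.Tensor, indexes: torch.Tensor) -> torch.Tensor:
    # pos_emb: (B, S, S, P); indexes: (B, R)  ->  (B, R, R, P)
    b = torch.arange(pos_emb.shape[0])[:, None, None]
    return pos_emb[b, indexes[:, :, None], indexes[:, None, :]]

B, S, P, R = 2, 6, 4, 3
pos_emb = torch.randn(B, S, S, P)
indexes = torch.stack([torch.randperm(S)[:R].sort().values for _ in range(B)])
out = downsample_pos_emb_ref(pos_emb, indexes)
# spot-check: out[b, i, j] == pos_emb[b, indexes[b, i], indexes[b, j]]
assert torch.equal(out[1, 2, 0], pos_emb[1, indexes[1, 2], indexes[1, 0]])
# ---------------------------------------------------------------------------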
+    @staticmethod
+    def downsample_attn_offset(attn_offset: Tensor,
+                               indexes: Tensor,
+                               weights: Tensor,
+                               eps: float = 1.0e-05) -> Tensor:
+        """
+        Downsamples attn_offset and also modifies it to account for the weights in `weights`.
+        Args:
+            attn_offset: a Tensor of shape (1 or batch_size, seq_len, seq_len), interpreted as
+               (1 or batch_size, tgt_seq_len, src_seq_len)
+            indexes: a Tensor of shape (batch_size, reduced_seq_len) containing elements
+               0 <= indexes < seq_len.
+            weights: a Tensor of shape (batch_size, reduced_seq_len) containing weights
+               between 0 and 1; most will be 1.
+        Returns:
+            attn_offset_downsampled, a Tensor of shape (batch_size, reduced_seq_len, reduced_seq_len)
+        """
+        (batch_size, seq_len_reduced) = indexes.shape
+        seq_len = attn_offset.shape[-1]
+        assert len(attn_offset.shape) == 3  # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
+        attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
+
+        attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
+            batch_size, seq_len_reduced, seq_len))
+        attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
+            batch_size, seq_len_reduced, seq_len_reduced))
+        # unsqueeze at position 1 so the extra cost relates to the source position.
+        attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)
+
+        return attn_offset
+
+
+    @staticmethod
+    def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
+        """
+        Upsamples, reversing the downsample() operation and filling in
+        any not-chosen frames with their original value before downsampling
+        (or with whatever x_orig contains).
+
+        Args:
+            x_orig: (seq_len, batch_size, num_channels)
+            x: (seq_len_reduced, batch_size, num_channels)
+            indexes: (batch_size, seq_len_reduced), contains original frame indexes,
+                with elements 0 <= indexes < seq_len.
+        Returns:
+            a Tensor of shape (seq_len, batch_size, num_channels)
+        """
+        (seq_len, batch_size, num_channels) = x_orig.shape
+
+        not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
+                              device=x.device)
+        not_kept.scatter_(dim=1, index=indexes, value=False)
+
+        indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
+        # indexes now: (seq_len_reduced, batch_size, num_channels)
+
+        ans = torch.zeros_like(x_orig)
+
+        ans.scatter_(dim=0, index=indexes, src=x)
+
+        # add in x_orig in the frames that were not originally kept.
+        return ans + x_orig * not_kept.t().unsqueeze(-1)


 class DownsampledSubformer2Encoder(nn.Module):
-    r"""
+    """
     DownsampledSubformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
     after convolutional downsampling, and then upsampled again at the output, and combined
     with the origin input, so that the output has the same shape as the input.
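# ---------------------------------------------------------------------------
# Note (illustration only, not part of the patch): downsample_attn_offset()
# adds log(weight) along the source axis, so softmax(scores + log w) is
# proportional to w * exp(scores) -- a weight of (close to) 0 makes a frame
# effectively invisible to attention, which is the contract stated in the
# LearnedDownsamplingModule docstring.  A small check of that identity:
import torch

scores = torch.randn(3, 5)                       # (tgt, src) logits for one head
w = torch.tensor([1.0, 1.0, 0.5, 1.0, 0.0]).clamp(min=1.0e-05)

p = (scores + w.log().unsqueeze(0)).softmax(dim=-1)   # offset broadcast over tgt

p_ref = w * scores.exp()
p_ref = p_ref / p_ref.sum(dim=-1, keepdim=True)
torch.testing.assert_close(p, p_ref)
# ---------------------------------------------------------------------------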
@@ -818,18 +1039,18 @@ class DownsampledSubformer2Encoder(nn.Module):
                  dropout: FloatLike):
         super(DownsampledSubformer2Encoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = SimpleDownsample(dim,
-                                           downsample, dropout)
+        self.downsampler = LearnedDownsamplingModule(dim,
+                                                     downsample)
         self.encoder = encoder
-        self.upsample = SimpleUpsample(dim, downsample)
+        self.out_combiner = BypassModule(dim, straight_through_rate=0.025)

     def forward(self,
                 src: Tensor,
-                chunk_size: int = -1,
+                pos_emb: Tensor,
                 feature_mask: Union[Tensor, float] = 1.0,
-                attn_mask: Optional[Tensor] = None,
+                attn_offset: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
                 memory: Optional[Tensor] = None,
                 memory_key_padding_mask: Optional[Tensor] = None,
@@ -838,11 +1059,12 @@ class DownsampledSubformer2Encoder(nn.Module):
         Args:
             src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
             feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+            attn_offset: the attention offset, added to the attention scores; of shape
+               (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-               True means masked position.  May be None.
             src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position.  May be None.
             memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
@@ -853,16 +1075,20 @@ class DownsampledSubformer2Encoder(nn.Module):
         Returns: a Tensor with the same shape as src.
         """
         src_orig = src
-        src = self.downsample(src)
-        ds = self.downsample_factor
-        if attn_mask is not None:
-            attn_mask = attn_mask[::ds,::ds]
+        indexes, weights, src = self.downsampler(src)
+
+        pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
+
+        attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
+                                                              indexes,
+                                                              weights.clamp(min=1.0e-05))
+
         src = self.encoder(
             src,
-            chunk_size=chunk_size // ds,
+            pos_emb,
             feature_mask=feature_mask,
-            attn_mask=attn_mask,
+            attn_offset=attn_offset,
             src_key_padding_mask=src_key_padding_mask,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
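# ---------------------------------------------------------------------------
# Note (illustration only, not part of the patch): per the class docstring,
# the encoder output is later scattered back to the original frame rate
# (upsample) and combined with the original input.  A minimal sketch of the
# downsample/upsample round trip this relies on (names made up for the
# example):
import torch

S, B, C, R = 8, 2, 3, 4
x = torch.randn(S, B, C)
indexes = torch.stack([torch.randperm(S)[:R].sort().values for _ in range(B)])  # (B, R)

# downsample: gather the kept frames along the time axis
idx = indexes.t().unsqueeze(-1).expand(R, B, C)      # (R, B, C)
x_ds = x.gather(dim=0, index=idx)

# upsample: scatter the kept frames back; fill the rest from the original input
kept = torch.zeros(B, S, dtype=torch.bool)
kept[torch.arange(B).unsqueeze(-1), indexes] = True
x_us = torch.zeros_like(x).scatter_(0, idx, x_ds)
x_us = x_us + x * (~kept).t().unsqueeze(-1)
assert torch.allclose(x_us, x)
# ---------------------------------------------------------------------------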
- """ - def __init__(self, - num_channels: int, - upsample: int): - super(SimpleUpsample, self).__init__() - self.upsample = upsample - - def forward(self, - src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src class CompactRelPositionalEncoding(torch.nn.Module): @@ -967,17 +1122,19 @@ class CompactRelPositionalEncoding(torch.nn.Module): Args: - embed_dim: Embedding dimension. + embed_dim: Temporary embedding dimension used inside this module dropout_rate: Dropout rate. max_len: Maximum input length: just a heuristic for initialization. length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives less weight to small differences of offset near the origin. + pos_dim: dimension at the output of this module. """ def __init__( - self, embed_dim: int, + self, embed_dim: int, dropout_rate: FloatLike, max_len: int = 1000, length_factor: float = 1.0, + pos_dim: int = 4, ) -> None: """Construct a CompactRelPositionalEncoding object.""" super(CompactRelPositionalEncoding, self).__init__() @@ -989,6 +1146,11 @@ class CompactRelPositionalEncoding(torch.nn.Module): self.length_factor = length_factor self.extend_pe(torch.tensor(0.0).expand(max_len)) + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(embed_dim, + pos_dim, + bias=False, + initial_scale=0.05) def extend_pe(self, x: Tensor) -> None: @@ -1046,27 +1208,45 @@ class CompactRelPositionalEncoding(torch.nn.Module): """Create positional encoding. Args: - x (torch.Tensor): Input tensor (time, batch, `*`). + x (torch.Tensor): Input tensor (time, batch, num_channels_in) Returns: - positional embedding, of shape (1, 2*time-1, `*`). + positional embedding, of shape (1, 2*time-1, pos_dim). """ self.extend_pe(x) + seq_len = x.size(0) pos_emb = self.pe[ self.pe.size(0) // 2 - - x.size(0) + - seq_len, + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), + + seq_len, : ] pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) + pos_emb = self.dropout(pos_emb) + pos_emb = self.linear_pos(pos_emb) + + # currenly pos_emb: (1, 2*seq_len-1, pos_dim) + pos_dim = pos_emb.shape[-1] + batch_size = x.size(1) + (_, seq_stride, channel_stride) = pos_emb.stride() + # it doesn't really matter which one we make positive and which negative here, it + # would just flip the meaning of the embedding. + pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim), + (0, -seq_stride, seq_stride, channel_stride), + storage_offset=seq_stride * (seqs_len - 1)) + + return pos_emb # (batch_size, seq_len, seq_len, pos_dim) + class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. + r"""Module that computes multi-head attention weights with relative position encoding; + in this version, the positions for each frame are passed in (in order to support + + Various other modules consume the resulting attention weights: see, for example, the SimpleAttention module which allows you to compute conventional attention. @@ -1076,22 +1256,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module): Args: embed_dim: number of channels at the input to this module, e.g. 

 class RelPositionMultiheadAttentionWeights(nn.Module):
-    r"""Module that computes multi-head attention weights with relative position encoding.
+    r"""Module that computes multi-head attention weights with relative position encoding;
+    in this version, the relative positions for each pair of frames are passed in
+    explicitly (in order to support learned, uneven downsampling of the frames).
+
     Various other modules consume the resulting attention weights: see, for example, the
     SimpleAttention module which allows you to compute conventional attention.

@@ -1076,22 +1256,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
     Args:
           embed_dim: number of channels at the input to this module, e.g. 256
-          pos_dim: dimension of the positional encoding vectors, e.g. 128.
          num_heads:  number of heads to compute weights for, e.g. 8
          query_head_dim: dimension of the query (and key), per head.  e.g. 24.
-          pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
+          pos_dim: dimension of the projected positional encoding, e.g. 4.
          dropout: dropout probability for attn_output_weights. Default: 0.0.
-          pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
-             any given call to forward(), in training time.
+          pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
+             any given call to forward(), in training time.
    """

    def __init__(
            self,
            embed_dim: int,
-            pos_dim: int,
            num_heads: int,
            query_head_dim: int,
-            pos_head_dim: int,
+            pos_dim: int,
            dropout: float = 0.0,
            pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
                                                          (4000.0, 0.0))
@@ -1100,13 +1278,13 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.query_head_dim = query_head_dim
-        self.pos_head_dim = pos_head_dim
+        self.pos_dim = pos_dim
         self.dropout = dropout
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         self.name = None  # will be overwritten in training code; for diagnostics.

         key_head_dim = query_head_dim
-        in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
+        in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads

         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5 that has been used in previous forms of attention,
@@ -1138,13 +1316,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                              prob=0.025)

-        # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(pos_dim,
-                                       num_heads * pos_head_dim,
-                                       bias=False,
-                                       initial_scale=0.05)
-
-
         # the following are for diagnosics only, see --print-diagnostics option
         self.copy_pos_query = Identity()
         self.copy_query = Identity()
@@ -1154,27 +1325,30 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             self,
             x: Tensor,
             pos_emb: Tensor,
-            chunk_size: int = -1,
+            attn_offset: Optional[Tensor] = None,
             key_padding_mask: Optional[Tensor] = None,
-            attn_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""
         Args:
             x: input of shape (seq_len, batch_size, embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim)
-            chunk_size
-            attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
-               interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
-               saying which positions are allowed to attend to which other positions.
+            pos_emb: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len, pos_dim)
+               (e.g. pos_dim=4), encoding relative positions.
+            attn_offset: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len),
+               interpreted as (batch_size, tgt_seq_len, src_seq_len); if provided, it
+               contains values (probably <= 0) to be added to the logprobs of the attention;
+               this may combine the log of the 'weights' of LearnedDownsamplingModule with
+               any mask that enforces causality.
             key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
               are True in this mask will be ignored as sources in the attention weighting.
         Returns:
           a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len)
          interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
""" x = self.in_proj(x) query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim + pos_dim = self.pos_dim num_heads = self.num_heads seq_len, batch_size, _ = x.shape @@ -1185,7 +1359,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): k = x[...,query_dim:2*query_dim] # p is the position-encoding query p = x[...,2*query_dim:] - assert p.shape[-1] == num_heads * pos_head_dim + assert p.shape[-1] == num_heads * pos_dim q = self.copy_query(q) # for diagnostics only, does nothing.