More progress on subformer

This commit is contained in:
Daniel Povey 2023-05-15 10:57:48 +08:00
parent 5c470fe397
commit f740282a1a

View File

@ -93,12 +93,12 @@ class Subformer2(EncoderInterface):
num_encoder_layers: Union[int, Tuple[int]] = 4,
encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
query_head_dim: Union[int, Tuple[int]] = 24,
pos_head_dim: Union[int, Tuple[int]] = 4,
value_head_dim: Union[int, Tuple[int]] = 12,
num_heads: Union[int, Tuple[int]] = 8,
feedforward_dim: Union[int, Tuple[int]] = 1536,
memory_dim: int = -1,
pos_dim: int = 192,
pos_emb_dim: int = 192,
pos_dim: int = 4,
dropout: FloatLike = None, # see code below for default
warmup_batches: float = 4000.0,
causal: bool = False,
@ -127,7 +127,6 @@ class Subformer2(EncoderInterface):
num_encoder_layers = _to_tuple(num_encoder_layers)
query_head_dim = _to_tuple(query_head_dim)
value_head_dim = _to_tuple(value_head_dim)
pos_head_dim = _to_tuple(pos_head_dim)
num_heads = _to_tuple(num_heads)
feedforward_dim = _to_tuple(feedforward_dim)
@ -145,7 +144,6 @@ class Subformer2(EncoderInterface):
pos_dim=pos_dim,
num_heads=num_heads[i],
query_head_dim=query_head_dim[i],
pos_head_dim=pos_head_dim[i],
value_head_dim=value_head_dim[i],
feedforward_dim=feedforward_dim[i],
memory_dim=memory_dim,
@ -175,6 +173,9 @@ class Subformer2(EncoderInterface):
encoders.append(encoder)
self.encoder_pos = CompactRelPositionalEncoding(pos_emb_dim, pos_dim, dropout_rate=0.15,
length_factor=1.0)
self.encoders = nn.ModuleList(encoders)
self.downsample_output = SimpleDownsample(max(encoder_dim),
@ -272,7 +273,7 @@ class Subformer2(EncoderInterface):
outputs = []
feature_masks = self.get_feature_masks(x)
attn_mask = self._get_attn_mask(x)
attn_offset = self._get_attn_offset(x)
if self.training and memory is not None:
batch_size = x.shape[1]
@ -282,16 +283,19 @@ class Subformer2(EncoderInterface):
memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
memory_dropout_rate)
pos_emb = self.encoder_pos(x)
for i, module in enumerate(self.encoders):
ds = self.downsampling_factor[i]
x = convert_num_channels(x, self.encoder_dim[i])
x = module(x,
pos_emb,
chunk_size=chunk_size,
feature_mask=feature_masks[i],
src_key_padding_mask=(None if src_key_padding_mask is None
else src_key_padding_mask[...,::ds]),
attn_mask=attn_mask,
attn_offset=attn_offset,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
@ -324,11 +328,11 @@ class Subformer2(EncoderInterface):
return x, lengths
def _get_attn_mask(self, x: Tensor) -> Optional[Tensor]:
def _get_attn_offset(self, x: Tensor) -> Optional[Tensor]:
"""
Return None if not self.causal is false else return attention mask of shape
(seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True
means a masked position.
Return attention offset of shape (1, seq_len, seq_len), interpreted as (tgt_seq_len,
src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros.
Args:
x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
chunk_size: chunk size, must divide
@ -345,7 +349,11 @@ class Subformer2(EncoderInterface):
attn_mask = (src_c > tgt_c)
return attn_mask
ans = torch.zeros(1, seq_len, seq_len, device=x.device)
ans.masked_fill(attn_mask, float('-inf'))
return ans
@ -379,7 +387,7 @@ class Subformer2EncoderLayer(nn.Module):
pos_dim: int,
num_heads: int,
query_head_dim: int,
pos_head_dim: int,
pos_dim: int,
value_head_dim: int,
feedforward_dim: int,
dropout: FloatLike = 0.1,
@ -416,8 +424,8 @@ class Subformer2EncoderLayer(nn.Module):
self.const_attention_rate = copy.deepcopy(const_attention_rate)
self.self_attn_weights = RelPositionMultiheadAttentionWeights(
embed_dim, pos_dim=pos_dim, num_heads=num_heads,
query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
embed_dim, num_heads=num_heads,
query_head_dim=query_head_dim, pos_dim=pos_dim,
dropout=0.0,
)
@ -552,7 +560,7 @@ class Subformer2EncoderLayer(nn.Module):
src: Tensor,
pos_emb: Tensor,
chunk_size: int = -1,
attn_mask: Optional[Tensor] = None,
attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
@ -565,10 +573,9 @@ class Subformer2EncoderLayer(nn.Module):
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
True means masked position. May be None.
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len). -inf for masked position.
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
Returns:
@ -583,7 +590,7 @@ class Subformer2EncoderLayer(nn.Module):
attn_weights = self.self_attn_weights(
src,
pos_emb=pos_emb,
attn_mask=attn_mask,
attn_offset=attn_offset,
key_padding_mask=src_key_padding_mask,
)
@ -675,9 +682,6 @@ class Subformer2Encoder(nn.Module):
final_layerdrop_rate: float = 0.05,
) -> None:
super().__init__()
self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
length_factor=1.0)
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
@ -699,7 +703,7 @@ class Subformer2Encoder(nn.Module):
src: Tensor,
chunk_size: int = -1,
feature_mask: Union[Tensor, float] = 1.0,
attn_mask: Optional[Tensor] = None,
attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
@ -711,9 +715,9 @@ class Subformer2Encoder(nn.Module):
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
True means masked position. May be None.
attn_offset: the attention offset (does masking and related tasks), of shape
broadcasting with (batch_size, seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len).
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
@ -735,7 +739,7 @@ class Subformer2Encoder(nn.Module):
output,
pos_emb,
chunk_size=chunk_size,
attn_mask=attn_mask,
attn_offset=attn_offset,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
@ -803,10 +807,227 @@ class BypassModule(nn.Module):
return src_orig + (src - src_orig) * bypass_scale
class LearnedDownsamplingModule(nn.Module):
"""
Module that allows you to choose which frames to keep for transformer-type
modules. Effectively downsampling, but not necessarily "evenly"- you just
keep some proportion of frames determined by the embedding.
Args:
embed_dim: embedding dimension
downsampling_factor: factor to downsample by, e.g. 2 or 4. There is no
fundamental reason why this has to be an integer, but we make it so
anyway.
intermediate_rate: the proportion of the downsampled values that have
"intermediate weights"- between kept and downsampled. The user is
supposed to use these in such a way that if the weight we return is
0.0, it's equivalent to not using this frame at all.
"""
def __init__(self,
embed_dim: int,
downsampling_factor: int,
intermediate_rate: FloatLike = 0.2):
self.to_scores = nn.Linear(embed_dim, 1, bias=False)
# score_balancer is just to keep the magnitudes of the scores in
# a fixed range and keep them balanced around zero, to stop
# these drifting around.
self.score_balancer = Balancer(1, channel_dim=-1,
min_positive=0.4, max_positive=0.6
min_abs=1.0, max_abs=1.2,
prob=0.025)
self.downsampling_factor = downsampling_factor
self.intermediate_rate = copy.deepcopy(intermediate_rate)
def forward(self,
x: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
x: a Tensor of shape (seq_len, batch_size, embed_dim)
Returns: (frame_indexes, weights, kept)
frame_indexes: a Tensor of integer type, of shape (batch_size, reduced_seq_len)
where reduced_seq_len = (seq_len + d - 1) // d. It contains elements
0 <= frame_indees < seq_len, in sorted (increasing) order
weights: a Tensor of shape (batch_size, reduced_seq_len),
corresponding to the kept frames; these will be between 0 and 1, but
mostly exactly 1.
"""
(seq_len, batch_size, _)
scores = self.to_scores(x) # (seq_len, batch_size, 1)
scores = self.score_balancer(scores)
scores = scores.squeeze(-1).t() # (batch_size, seq_len)
# indexes, sscores: (batch_size, seq_len)
indexes, sscores = scores.sort(dim=-1, descending=True)
d = self.downsampling_factor
seq_len_reduced = (seq_len + d - 1) // d
# TODO: if seq_len / downsampling_factor <= 2, do something special.
# 'right' is the rightmost of the 2 limits; we want the scores indexed
# 'upper' to be mapped to around 0.0
right = seq_len_reduced
# we want scores around 'left' to be mapped to around 1.0.
left = int(seq_len_reduced * (1.0 - self.intermediate_rate))
# 'collar' determines the range of positions in the sorted list that we use to
# compute the average. We could let collar be 0.0, which would more exactly
# accomplish what we want; but we don't, because this would cause too-noisy
# gradients, with too much gradient going to one frame.
collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate))
# right_avg: shape (batch_size,), this is to be mapped to 0.0
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1)
# left_avg: shape (batch_size,), this is to be mapped to 1.0
left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1)
# the + 0.001 is to avoid possible division by zero in case of ties.
weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
weights = weights.clamp(min=0.0, max=1.0)
indexes = indexes[:, :seq_len_reduced]
weights = weights[:, :seq_len_reduced]
# re-sort the indexes we kept, on index value, so that
# masking for causal models will be in the correct order.
indexes, reorder = indexes.sort(dim=-1)
weights = torch.gather(weights, dim=-1, index=reorder)
x_downsampled = downsample(indexes, x)
return indexes, weights, x_downsampled
def downsample(x: Tensor, indexes: Tensor) -> Tensor:
"""
Downsamples x via indexing with the indexes obtained from the
forward() function.
Args:
x: tensor of shape (seq_len, batch_size, num_channels)
indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements
0 <= indexes < seq_len.
Returns:
x_downsampled, of shape (seq_len_reduced, batch_size, num_channels)
"""
indexes = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
# indexes now: (seq_len_reduced, batch_size, num_channels)
ans = torch.gather(x, dim=0, index=indexes)
if __name__ == __main__:
# temp, for testing
x_reconstructed = upsample(x, ans, indexes)
assert torch.allclose(x, x_reconstructed)
return ans
def downsample_pos_emb(pos_emb: Tensor, indexes: Tensor) -> Tensor:
"""
Downsample positional embedding tensor with the provided indexes.
Args:
pos_emb: (batch_size, seq_len, seq_len, pos_dim)
interpreted as (batch_size, tgt_seq_len, src_seq_len, pos_dim).
indexes: (batch_size, seq_len_reduced), containing integer elements
0 <= indexes < seq_len.
Returns:
downsampled_pos_len: (batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
"""
(batch_size, seq_len_reduced) = indexes.shape
(_, _, seq_len, pos_dim) = pos_emb.shape
tgt_indexes = indexes.reshape(batch_size, seq_len_reduced, 1, 1).expand(
batch_size, seq_len_reduced, seq_len, pos_dim)
pos_emb = torch.gather(pos_emb, dim=1, index=tgt_indexes)
# now pos_emb: (batch_size, seq_len_reduced, seq_len, pos_dim)
src_indexes = indexes.reshape(batch_size, 1, seq_len_reduced, 1).expand(
batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
pos_emb = torch.gather(pos_emb, dim=2, index=src_indexes)
# now pos_emb: (batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
return pos_emb
def downsample_attn_offset(attn_offset: Tensor,
indexes: Tensor,
weights: Tensor,
eps: float = 1.0e-05) -> Tensor:
"""
Downsamples attn_offset and also modifies it to account for the weights in `weights`.
Args:
attn_offset: a Tensor of shape (1 or batch_size, seq_len, seq_len), interpreted as
(1 or batch_size, tgt_seq_len, src_seq_len)
indexes: a Tensor of shape (batch_size, reduced_seq_len) containing elements
0 <= indexes < seq_len.
weights: a Tensor of shape (batch_size, reduced_seq_len) containing weights
between 0 and 1; most will be 1.
Returns:
attn_offset_downsampled, a Tensor of shape (batch_size, reduced_seq_len, reduced_seq_len)
"""
(batch_size, seq_len_reduced) = indexes.shape
seq_len = attn_offset.shape[-1]
assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
attn_offset = attn_offset.gather(dim=1, src=indices.unsqueeze(-1).expand(
batch_size, seq_len_reduced, seq_len))
attn_offset = attn_offset.gather(dim=2, src=indices.unsqueeze(1).expand(
batch_size, seq_len_reduced, seq_len_reduced))
# unsqueeze at position 1 so the extra cost relates to the source position.
attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)
return attn_offst
def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
"""
Upsamples, reversing the downsample() operation and filling in
any not-chosen frames with their original value before downsampling
(or with whatever x_orig contains).
Args:
x_orig: (seq_len, batch_size, num_channels)
x: (seq_len_reduced, batch_size, num_channels)
indexes: (batch_size, seq_len_reduced), contains original frame indexes
Downsamples x via indexing with the indexes obtained from the
forward() function.
Args:
x: tensor of shape (seq_len, batch_size, indexes)
indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements
0 <= indexes < seq_len.
"""
(seq_len, batch_size, num_channels) = x_orig.shape
not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
device=x.device)
not_kept.scatter_(src=False, dim=1, index=indexes)
indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
# indexes now: (seq_len_reduced, batch_size, num_channels)
ans = torch.zeros_like(x_orig)
ans.scatter_(x, dim=0, index=indexes)
# add in x_orig in the frames that were not originally kept.
return ans + x_orig * not_kept.t().unsqueeze(-1)
class DownsampledSubformer2Encoder(nn.Module):
r"""
"""
DownsampledSubformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
after convolutional downsampling, and then upsampled again at the output, and combined
with the origin input, so that the output has the same shape as the input.
@ -818,18 +1039,18 @@ class DownsampledSubformer2Encoder(nn.Module):
dropout: FloatLike):
super(DownsampledSubformer2Encoder, self).__init__()
self.downsample_factor = downsample
self.downsample = SimpleDownsample(dim,
downsample, dropout)
self.downsampler = LearnedDownsamplingModule(dim,
downsample)
self.encoder = encoder
self.upsample = SimpleUpsample(dim, downsample)
self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
def forward(self,
src: Tensor,
chunk_size: int = -1,
pos_emb: Tensor,
feature_mask: Union[Tensor, float] = 1.0,
attn_mask: Optional[Tensor] = None,
attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
@ -838,11 +1059,12 @@ class DownsampledSubformer2Encoder(nn.Module):
Args:
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
attn_offset: the attention offset, added to scores for attention of shape
(batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
True means masked position. May be None.
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
@ -853,16 +1075,20 @@ class DownsampledSubformer2Encoder(nn.Module):
Returns: a Tensor with the same shape as src.
"""
src_orig = src
src = self.downsample(src)
ds = self.downsample_factor
if attn_mask is not None:
attn_mask = attn_mask[::ds,::ds]
indexes, weights, src = self.downsampler(src)
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
attn_offset = self.downsample.downsample_attn_offset(attn_offset,
indexes,
weights.clamp(min=1.0e-05))
src = self.encoder(
src,
chunk_size=chunk_size // ds,
os_emb,
feature_mask=feature_mask,
attn_mask=attn_mask,
attn_offset=attn_offset,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
@ -875,77 +1101,6 @@ class DownsampledSubformer2Encoder(nn.Module):
class SimpleDownsample(torch.nn.Module):
"""
Does downsampling with attention, by weighted sum, and a projection..
"""
def __init__(self,
channels: int,
downsample: int,
dropout: FloatLike):
super(SimpleDownsample, self).__init__()
self.bias = nn.Parameter(torch.zeros(downsample))
self.name = None # will be set from training code
self.dropout = copy.deepcopy(dropout)
self.downsample = downsample
def forward(self,
src: Tensor) -> Tensor:
"""
x: (seq_len, batch_size, in_channels)
Returns a tensor of shape
( (seq_len+downsample-1)//downsample, batch_size, channels)
"""
(seq_len, batch_size, in_channels) = src.shape
ds = self.downsample
d_seq_len = (seq_len + ds - 1) // ds
# Pad to an exact multiple of self.downsample
if seq_len != d_seq_len * ds:
# right-pad src, repeating the last element.
pad = d_seq_len * ds - seq_len
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
weights = self.bias.softmax(dim=0)
# weights: (downsample, 1, 1)
weights = weights.unsqueeze(-1).unsqueeze(-1)
# ans1 is the first `in_channels` channels of the output
ans = (src * weights).sum(dim=1)
return ans
class SimpleUpsample(torch.nn.Module):
"""
A very simple form of upsampling that mostly just repeats the input, but
also adds a position-specific bias.
"""
def __init__(self,
num_channels: int,
upsample: int):
super(SimpleUpsample, self).__init__()
self.upsample = upsample
def forward(self,
src: Tensor) -> Tensor:
"""
x: (seq_len, batch_size, num_channels)
Returns a tensor of shape
( (seq_len*upsample), batch_size, num_channels)
"""
upsample = self.upsample
(seq_len, batch_size, num_channels) = src.shape
src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
src = src.reshape(seq_len * upsample, batch_size, num_channels)
return src
class CompactRelPositionalEncoding(torch.nn.Module):
@ -967,17 +1122,19 @@ class CompactRelPositionalEncoding(torch.nn.Module):
Args:
embed_dim: Embedding dimension.
embed_dim: Temporary embedding dimension used inside this module
dropout_rate: Dropout rate.
max_len: Maximum input length: just a heuristic for initialization.
length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
less weight to small differences of offset near the origin.
pos_dim: dimension at the output of this module.
"""
def __init__(
self, embed_dim: int,
self, embed_dim: int,
dropout_rate: FloatLike,
max_len: int = 1000,
length_factor: float = 1.0,
pos_dim: int = 4,
) -> None:
"""Construct a CompactRelPositionalEncoding object."""
super(CompactRelPositionalEncoding, self).__init__()
@ -989,6 +1146,11 @@ class CompactRelPositionalEncoding(torch.nn.Module):
self.length_factor = length_factor
self.extend_pe(torch.tensor(0.0).expand(max_len))
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim,
pos_dim,
bias=False,
initial_scale=0.05)
def extend_pe(self, x: Tensor) -> None:
@ -1046,27 +1208,45 @@ class CompactRelPositionalEncoding(torch.nn.Module):
"""Create positional encoding.
Args:
x (torch.Tensor): Input tensor (time, batch, `*`).
x (torch.Tensor): Input tensor (time, batch, num_channels_in)
Returns:
positional embedding, of shape (1, 2*time-1, `*`).
positional embedding, of shape (1, 2*time-1, pos_dim).
"""
self.extend_pe(x)
seq_len = x.size(0)
pos_emb = self.pe[
self.pe.size(0) // 2
- x.size(0)
- seq_len,
+ 1 : self.pe.size(0) // 2 # noqa E203
+ x.size(0),
+ seq_len,
:
]
pos_emb = pos_emb.unsqueeze(0)
return self.dropout(pos_emb)
pos_emb = self.dropout(pos_emb)
pos_emb = self.linear_pos(pos_emb)
# currenly pos_emb: (1, 2*seq_len-1, pos_dim)
pos_dim = pos_emb.shape[-1]
batch_size = x.size(1)
(_, seq_stride, channel_stride) = pos_emb.stride()
# it doesn't really matter which one we make positive and which negative here, it
# would just flip the meaning of the embedding.
pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim),
(0, -seq_stride, seq_stride, channel_stride),
storage_offset=seq_stride * (seqs_len - 1))
return pos_emb # (batch_size, seq_len, seq_len, pos_dim)
class RelPositionMultiheadAttentionWeights(nn.Module):
r"""Module that computes multi-head attention weights with relative position encoding.
r"""Module that computes multi-head attention weights with relative position encoding;
in this version, the positions for each frame are passed in (in order to support
Various other modules consume the resulting attention weights: see, for example, the
SimpleAttention module which allows you to compute conventional attention.
@ -1076,22 +1256,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
Args:
embed_dim: number of channels at the input to this module, e.g. 256
pos_dim: dimension of the positional encoding vectors, e.g. 128.
num_heads: number of heads to compute weights for, e.g. 8
query_head_dim: dimension of the query (and key), per head. e.g. 24.
pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
pos_dim: dimension of the projected positional encoding, e.g. 4.
dropout: dropout probability for attn_output_weights. Default: 0.0.
pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
any given call to forward(), in training time.
pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
any given call to forward(), in training time.
"""
def __init__(
self,
embed_dim: int,
pos_dim: int,
num_heads: int,
query_head_dim: int,
pos_head_dim: int,
pos_dim: int,
dropout: float = 0.0,
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
(4000.0, 0.0))
@ -1100,13 +1278,13 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query_head_dim = query_head_dim
self.pos_head_dim = pos_head_dim
self.pos_dim = pos_dim
self.dropout = dropout
self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
self.name = None # will be overwritten in training code; for diagnostics.
key_head_dim = query_head_dim
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads
# the initial_scale is supposed to take over the "scaling" factor of
# head_dim ** -0.5 that has been used in previous forms of attention,
@ -1138,13 +1316,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
prob=0.025)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(pos_dim,
num_heads * pos_head_dim,
bias=False,
initial_scale=0.05)
# the following are for diagnosics only, see --print-diagnostics option
self.copy_pos_query = Identity()
self.copy_query = Identity()
@ -1154,27 +1325,30 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self,
x: Tensor,
pos_emb: Tensor,
chunk_size: int = -1,
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
attn_offset: Optional[Tensor] = None,
pos_emb: Tensor,
quadratic_pos_weight: Tensor,
) -> Tensor:
r"""
Args:
x: input of shape (seq_len, batch_size, embed_dim)
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim)
chunk_size
key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
are True in this mask will be ignored as sources in the attention weighting.
attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
saying which positions are allowed to attend to which other positions.
attn_offset: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len), if provided this
contains values (probably <= 0) to be added to the logprobs of the attention;
this may combine the log of 'weights' of ChooseDownsamplingModule with
any attn_mask that enforces causality.
pos_emb: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len, pos_dim)
(e.g. pos_dim=4), encoding relative positions.
Returns:
a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len)
interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
"""
x = self.in_proj(x)
query_head_dim = self.query_head_dim
pos_head_dim = self.pos_head_dim
pos_dim = self.pos_dim
num_heads = self.num_heads
seq_len, batch_size, _ = x.shape
@ -1185,7 +1359,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
k = x[...,query_dim:2*query_dim]
# p is the position-encoding query
p = x[...,2*query_dim:]
assert p.shape[-1] == num_heads * pos_head_dim
assert p.shape[-1] == num_heads * pos_dim
q = self.copy_query(q) # for diagnostics only, does nothing.