More progress on subformer

Daniel Povey 2023-05-15 10:57:48 +08:00
parent 5c470fe397
commit f740282a1a


@@ -93,12 +93,12 @@ class Subformer2(EncoderInterface):
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
        query_head_dim: Union[int, Tuple[int]] = 24,
-       pos_head_dim: Union[int, Tuple[int]] = 4,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        memory_dim: int = -1,
-       pos_dim: int = 192,
+       pos_emb_dim: int = 192,
+       pos_dim: int = 4,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
        causal: bool = False,

@@ -127,7 +127,6 @@ class Subformer2(EncoderInterface):
        num_encoder_layers = _to_tuple(num_encoder_layers)
        query_head_dim = _to_tuple(query_head_dim)
        value_head_dim = _to_tuple(value_head_dim)
-       pos_head_dim = _to_tuple(pos_head_dim)
        num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)

@@ -145,7 +144,6 @@ class Subformer2(EncoderInterface):
                pos_dim=pos_dim,
                num_heads=num_heads[i],
                query_head_dim=query_head_dim[i],
-               pos_head_dim=pos_head_dim[i],
                value_head_dim=value_head_dim[i],
                feedforward_dim=feedforward_dim[i],
                memory_dim=memory_dim,

@@ -175,6 +173,9 @@ class Subformer2(EncoderInterface):
            encoders.append(encoder)

+       self.encoder_pos = CompactRelPositionalEncoding(pos_emb_dim, pos_dim, dropout_rate=0.15,
+                                                       length_factor=1.0)
        self.encoders = nn.ModuleList(encoders)
        self.downsample_output = SimpleDownsample(max(encoder_dim),

@@ -272,7 +273,7 @@ class Subformer2(EncoderInterface):
        outputs = []
        feature_masks = self.get_feature_masks(x)
-       attn_mask = self._get_attn_mask(x)
+       attn_offset = self._get_attn_offset(x)
        if self.training and memory is not None:
            batch_size = x.shape[1]

@@ -282,16 +283,19 @@ class Subformer2(EncoderInterface):
            memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
                               memory_dropout_rate)

+       pos_emb = self.encoder_pos(x)
        for i, module in enumerate(self.encoders):
            ds = self.downsampling_factor[i]
            x = convert_num_channels(x, self.encoder_dim[i])
            x = module(x,
+                      pos_emb,
                       chunk_size=chunk_size,
                       feature_mask=feature_masks[i],
                       src_key_padding_mask=(None if src_key_padding_mask is None
                                             else src_key_padding_mask[...,::ds]),
-                      attn_mask=attn_mask,
+                      attn_offset=attn_offset,
                       memory=memory,
                       memory_key_padding_mask=memory_key_padding_mask,
                       )

@@ -324,11 +328,11 @@ class Subformer2(EncoderInterface):
        return x, lengths

-   def _get_attn_mask(self, x: Tensor) -> Optional[Tensor]:
+   def _get_attn_offset(self, x: Tensor) -> Optional[Tensor]:
        """
-       Return None if not self.causal is false else return attention mask of shape
-       (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
-       means a masked position.
+       Return attention offset of shape (1, seq_len, seq_len), interpreted as (tgt_seq_len,
+       src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros.
        Args:
           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
           chunk_size: chunk size, must divide

@@ -345,7 +349,11 @@ class Subformer2(EncoderInterface):
        attn_mask = (src_c > tgt_c)
-       return attn_mask
+       ans = torch.zeros(1, seq_len, seq_len, device=x.device)
+       ans.masked_fill_(attn_mask, float('-inf'))
+       return ans
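
A minimal sketch (not from the commit) of the equivalence the change above relies on: adding 0 to an attention logit keeps a position, adding -inf removes it after the softmax, so an additive offset can replace a boolean mask. Names and shapes here are illustrative.

import torch

seq_len = 4
scores = torch.randn(1, seq_len, seq_len)          # raw attention logits

# boolean causal mask: True = not allowed to attend
bool_mask = torch.arange(seq_len)[None, :] > torch.arange(seq_len)[:, None]

# equivalent additive offset: 0 where allowed, -inf where masked
offset = torch.zeros(1, seq_len, seq_len)
offset.masked_fill_(bool_mask, float('-inf'))

old_style = scores.masked_fill(bool_mask, float('-inf')).softmax(dim=-1)
new_style = (scores + offset).softmax(dim=-1)
assert torch.allclose(old_style, new_style)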

@@ -379,7 +387,7 @@ class Subformer2EncoderLayer(nn.Module):
            pos_dim: int,
            num_heads: int,
            query_head_dim: int,
-           pos_head_dim: int,
+           pos_dim: int,
            value_head_dim: int,
            feedforward_dim: int,
            dropout: FloatLike = 0.1,

@@ -416,8 +424,8 @@ class Subformer2EncoderLayer(nn.Module):
        self.const_attention_rate = copy.deepcopy(const_attention_rate)

        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
-           embed_dim, pos_dim=pos_dim, num_heads=num_heads,
-           query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
+           embed_dim, num_heads=num_heads,
+           query_head_dim=query_head_dim, pos_dim=pos_dim,
            dropout=0.0,
        )

@@ -552,7 +560,7 @@ class Subformer2EncoderLayer(nn.Module):
        src: Tensor,
        pos_emb: Tensor,
        chunk_size: int = -1,
-       attn_mask: Optional[Tensor] = None,
+       attn_offset: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        memory: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,

@@ -565,9 +573,8 @@ class Subformer2EncoderLayer(nn.Module):
          chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
          feature_mask: something that broadcasts with src, that we'll multiply `src`
             by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-         attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-            interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-            True means masked position.  May be None.
+         attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
+            interpreted as (batch_size, tgt_seq_len, src_seq_len).  -inf for masked position.
          src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
             masked position.  May be None.

@@ -583,7 +590,7 @@ class Subformer2EncoderLayer(nn.Module):
            attn_weights = self.self_attn_weights(
                src,
                pos_emb=pos_emb,
-               attn_mask=attn_mask,
+               attn_offset=attn_offset,
                key_padding_mask=src_key_padding_mask,
            )

@@ -675,9 +682,6 @@ class Subformer2Encoder(nn.Module):
            final_layerdrop_rate: float = 0.05,
    ) -> None:
        super().__init__()
-       self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
-                                                       length_factor=1.0)
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )

@@ -699,7 +703,7 @@ class Subformer2Encoder(nn.Module):
        src: Tensor,
        chunk_size: int = -1,
        feature_mask: Union[Tensor, float] = 1.0,
-       attn_mask: Optional[Tensor] = None,
+       attn_offset: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        memory: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,

@@ -711,9 +715,9 @@ class Subformer2Encoder(nn.Module):
          chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
          feature_mask: something that broadcasts with src, that we'll multiply `src`
             by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-         attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-            interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-            True means masked position.  May be None.
+         attn_offset: the attention offset (does masking and related tasks), of shape
+            broadcasting with (batch_size, seq_len, seq_len),
+            interpreted as (batch_size, tgt_seq_len, src_seq_len).
          src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
             masked position.  May be None.
          memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)

@@ -735,7 +739,7 @@ class Subformer2Encoder(nn.Module):
                output,
                pos_emb,
                chunk_size=chunk_size,
-               attn_mask=attn_mask,
+               attn_offset=attn_offset,
                src_key_padding_mask=src_key_padding_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,

@@ -803,10 +807,227 @@ class BypassModule(nn.Module):
        return src_orig + (src - src_orig) * bypass_scale

+class LearnedDownsamplingModule(nn.Module):
+    """
+    Module that allows you to choose which frames to keep for transformer-type
+    modules.  Effectively downsampling, but not necessarily "evenly": you just
+    keep some proportion of frames determined by the embedding.
+
+    Args:
+      embed_dim: embedding dimension
+      downsampling_factor: factor to downsample by, e.g. 2 or 4.  There is no
+          fundamental reason why this has to be an integer, but we make it so
+          anyway.
+      intermediate_rate: the proportion of the downsampled values that have
+          "intermediate weights", in between kept and not-kept.  The user is
+          supposed to use these in such a way that if the weight we return is
+          0.0, it's equivalent to not using this frame at all.
+    """
+    def __init__(self,
+                 embed_dim: int,
+                 downsampling_factor: int,
+                 intermediate_rate: FloatLike = 0.2):
+        super().__init__()
+
+        self.to_scores = nn.Linear(embed_dim, 1, bias=False)
+        # score_balancer is just to keep the magnitudes of the scores in
+        # a fixed range and keep them balanced around zero, to stop
+        # these drifting around.
+        self.score_balancer = Balancer(1, channel_dim=-1,
+                                       min_positive=0.4, max_positive=0.6,
+                                       min_abs=1.0, max_abs=1.2,
+                                       prob=0.025)
+
+        self.downsampling_factor = downsampling_factor
+        self.intermediate_rate = copy.deepcopy(intermediate_rate)
+
+    def forward(self,
+                x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        Args:
+          x: a Tensor of shape (seq_len, batch_size, embed_dim)
+        Returns: (frame_indexes, weights, x_downsampled)
+          frame_indexes: a Tensor of integer type, of shape (batch_size, reduced_seq_len)
+              where reduced_seq_len = (seq_len + d - 1) // d.  It contains elements
+              0 <= frame_indexes < seq_len, in sorted (increasing) order.
+          weights: a Tensor of shape (batch_size, reduced_seq_len),
+              corresponding to the kept frames; these will be between 0 and 1, but
+              mostly exactly 1.
+          x_downsampled: the kept frames of x, of shape
+              (reduced_seq_len, batch_size, embed_dim).
+        """
+        (seq_len, batch_size, _) = x.shape
+        scores = self.to_scores(x)  # (seq_len, batch_size, 1)
+        scores = self.score_balancer(scores)
+        scores = scores.squeeze(-1).t()  # (batch_size, seq_len)
+
+        # sscores, indexes: (batch_size, seq_len)
+        sscores, indexes = scores.sort(dim=-1, descending=True)
+
+        d = self.downsampling_factor
+        seq_len_reduced = (seq_len + d - 1) // d
+        # TODO: if seq_len / downsampling_factor <= 2, do something special.
+
+        # 'right' is the rightmost of the 2 limits; we want the scores indexed
+        # around 'right' to be mapped to around 0.0.
+        right = seq_len_reduced
+        # we want scores around 'left' to be mapped to around 1.0.
+        left = int(seq_len_reduced * (1.0 - self.intermediate_rate))
+
+        # 'collar' determines the range of positions in the sorted list that we use to
+        # compute the average.  We could let collar be 0.0, which would more exactly
+        # accomplish what we want; but we don't, because this would cause too-noisy
+        # gradients, with too much gradient going to one frame.
+        collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate))
+
+        # right_avg: shape (batch_size, 1), this is to be mapped to 0.0
+        right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
+        # left_avg: shape (batch_size, 1), this is to be mapped to 1.0
+        left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
+
+        # the + 0.001 is to avoid possible division by zero in case of ties.
+        weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
+        weights = weights.clamp(min=0.0, max=1.0)
+
+        indexes = indexes[:, :seq_len_reduced]
+        weights = weights[:, :seq_len_reduced]
+
+        # re-sort the indexes we kept, on index value, so that
+        # masking for causal models will be in the correct order.
+        indexes, reorder = indexes.sort(dim=-1)
+        weights = torch.gather(weights, dim=-1, index=reorder)
+
+        x_downsampled = downsample(x, indexes)
+        return indexes, weights, x_downsampled
+
+def downsample(x: Tensor, indexes: Tensor) -> Tensor:
+    """
+    Downsamples x via indexing with the indexes obtained from the
+    forward() function.
+
+    Args:
+      x: tensor of shape (seq_len, batch_size, num_channels)
+      indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements
+          0 <= indexes < seq_len.
+    Returns:
+      x_downsampled, of shape (seq_len_reduced, batch_size, num_channels)
+    """
+    indexes_expanded = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
+    # indexes_expanded: (seq_len_reduced, batch_size, num_channels)
+    ans = torch.gather(x, dim=0, index=indexes_expanded)
+
+    if __name__ == '__main__':
+        # temp, for testing
+        x_reconstructed = upsample(x, ans, indexes)
+        assert torch.allclose(x, x_reconstructed)
+
+    return ans
+
+def downsample_pos_emb(pos_emb: Tensor, indexes: Tensor) -> Tensor:
+    """
+    Downsample positional embedding tensor with the provided indexes.
+
+    Args:
+      pos_emb: (batch_size, seq_len, seq_len, pos_dim),
+          interpreted as (batch_size, tgt_seq_len, src_seq_len, pos_dim).
+      indexes: (batch_size, seq_len_reduced), containing integer elements
+          0 <= indexes < seq_len.
+    Returns:
+      downsampled pos_emb, of shape (batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
+    """
+    (batch_size, seq_len_reduced) = indexes.shape
+    (_, _, seq_len, pos_dim) = pos_emb.shape
+
+    tgt_indexes = indexes.reshape(batch_size, seq_len_reduced, 1, 1).expand(
+        batch_size, seq_len_reduced, seq_len, pos_dim)
+    pos_emb = torch.gather(pos_emb, dim=1, index=tgt_indexes)
+    # now pos_emb: (batch_size, seq_len_reduced, seq_len, pos_dim)
+
+    src_indexes = indexes.reshape(batch_size, 1, seq_len_reduced, 1).expand(
+        batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
+    pos_emb = torch.gather(pos_emb, dim=2, index=src_indexes)
+    # now pos_emb: (batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
+    return pos_emb
+
+def downsample_attn_offset(attn_offset: Tensor,
+                           indexes: Tensor,
+                           weights: Tensor,
+                           eps: float = 1.0e-05) -> Tensor:
+    """
+    Downsamples attn_offset and also modifies it to account for the weights in `weights`.
+
+    Args:
+      attn_offset: a Tensor of shape (1 or batch_size, seq_len, seq_len), interpreted as
+          (1 or batch_size, tgt_seq_len, src_seq_len)
+      indexes: a Tensor of shape (batch_size, reduced_seq_len) containing elements
+          0 <= indexes < seq_len.
+      weights: a Tensor of shape (batch_size, reduced_seq_len) containing weights
+          between 0 and 1; most will be 1.
+    Returns:
+      attn_offset_downsampled, a Tensor of shape (batch_size, reduced_seq_len, reduced_seq_len)
+    """
+    (batch_size, seq_len_reduced) = indexes.shape
+    seq_len = attn_offset.shape[-1]
+    assert len(attn_offset.shape) == 3  # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
+    attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
+    attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
+        batch_size, seq_len_reduced, seq_len))
+    attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
+        batch_size, seq_len_reduced, seq_len_reduced))
+    # unsqueeze at position 1 so the extra cost relates to the source position.
+    attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)
+    return attn_offset
+
+def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
+    """
+    Upsamples, reversing the downsample() operation and filling in
+    any not-chosen frames with their original value before downsampling
+    (or with whatever x_orig contains).
+
+    Args:
+      x_orig: (seq_len, batch_size, num_channels)
+      x: (seq_len_reduced, batch_size, num_channels)
+      indexes: (batch_size, seq_len_reduced), contains original frame indexes,
+          with elements 0 <= indexes < seq_len.
+    Returns:
+      a Tensor of shape (seq_len, batch_size, num_channels)
+    """
+    (seq_len, batch_size, num_channels) = x_orig.shape
+
+    not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
+                          device=x.device)
+    not_kept.scatter_(dim=1, index=indexes, value=False)
+
+    indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
+    # indexes now: (seq_len_reduced, batch_size, num_channels)
+
+    ans = torch.zeros_like(x_orig)
+    ans.scatter_(dim=0, index=indexes, src=x)
+
+    # add in x_orig in the frames that were not originally kept.
+    return ans + x_orig * not_kept.t().unsqueeze(-1)

class DownsampledSubformer2Encoder(nn.Module):
-   r"""
+   """
    DownsampledSubformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
    after convolutional downsampling, and then upsampled again at the output, and combined
    with the origin input, so that the output has the same shape as the input.
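
A small self-contained sketch (not from the commit) of the gather/scatter indexing used by the downsample() and upsample() helpers added in the hunk above, with toy shapes and hand-picked frame indexes standing in for the learned scores:

import torch

seq_len, batch_size, num_channels = 6, 2, 3
x = torch.randn(seq_len, batch_size, num_channels)

# pretend these are the per-batch frame indexes chosen by the scoring module,
# already sorted in increasing order (seq_len_reduced = 3 here)
indexes = torch.tensor([[0, 2, 5],
                        [1, 3, 4]])              # (batch_size, seq_len_reduced)

# downsample: gather along the time axis, like downsample(x, indexes)
idx = indexes.t().unsqueeze(-1).expand(-1, -1, num_channels)
x_ds = torch.gather(x, dim=0, index=idx)         # (seq_len_reduced, batch_size, num_channels)

# upsample: scatter the kept frames back and let every other frame fall
# through to its original value, like upsample(x, x_ds, indexes)
kept = torch.zeros(batch_size, seq_len, dtype=torch.bool)
kept[torch.arange(batch_size).unsqueeze(-1), indexes] = True
ans = torch.zeros_like(x)
ans.scatter_(dim=0, index=idx, src=x_ds)
ans = ans + x * (~kept).t().unsqueeze(-1)

# kept frames are written back unchanged, the rest come from x, so the
# round trip reproduces the input exactly
assert torch.allclose(ans, x)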

@@ -818,18 +1039,18 @@ class DownsampledSubformer2Encoder(nn.Module):
                 dropout: FloatLike):
        super(DownsampledSubformer2Encoder, self).__init__()
        self.downsample_factor = downsample
-       self.downsample = SimpleDownsample(dim,
-                                          downsample, dropout)
+       self.downsampler = LearnedDownsamplingModule(dim,
+                                                    downsample)
        self.encoder = encoder
-       self.upsample = SimpleUpsample(dim, downsample)
        self.out_combiner = BypassModule(dim, straight_through_rate=0.025)

    def forward(self,
                src: Tensor,
-               chunk_size: int = -1,
+               pos_emb: Tensor,
                feature_mask: Union[Tensor, float] = 1.0,
-               attn_mask: Optional[Tensor] = None,
+               attn_offset: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                memory: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,

@@ -838,11 +1059,12 @@ class DownsampledSubformer2Encoder(nn.Module):
        Args:
          src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+         pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
          feature_mask: something that broadcasts with src, that we'll multiply `src`
             by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-         attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+         attn_offset: the attention offset, added to scores for attention, of shape
+            (batch_size, seq_len, seq_len) or (seq_len, seq_len),
             interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-            True means masked position.  May be None.
          src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
             masked position.  May be None.
          memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)

@@ -853,16 +1075,20 @@ class DownsampledSubformer2Encoder(nn.Module):
        Returns: a Tensor with the same shape as src.
        """
        src_orig = src
-       src = self.downsample(src)
-       ds = self.downsample_factor
-       if attn_mask is not None:
-           attn_mask = attn_mask[::ds,::ds]
+       indexes, weights, src = self.downsampler(src)
+
+       pos_emb = downsample_pos_emb(pos_emb, indexes)
+
+       attn_offset = downsample_attn_offset(attn_offset,
+                                            indexes,
+                                            weights.clamp(min=1.0e-05))

        src = self.encoder(
            src,
-           chunk_size=chunk_size // ds,
+           pos_emb,
            feature_mask=feature_mask,
-           attn_mask=attn_mask,
+           attn_offset=attn_offset,
            src_key_padding_mask=src_key_padding_mask,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
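
A short sketch (not from the commit) of what adding log(weights) along the source axis, as downsample_attn_offset() does, achieves: it is the same as scaling each source frame's attention probability by its keep-weight and renormalising, so a weight near 0 effectively removes that frame. Values below are illustrative.

import torch

# toy logits for one query over 4 downsampled source frames
logits = torch.randn(4)

# keep-weights from the learned downsampler; the last frame is "almost dropped"
weights = torch.tensor([1.0, 1.0, 0.7, 1e-5])

# adding log(weights) to the attention offset on the source axis ...
offset = weights.clamp(min=1e-5).log()
attn = (logits + offset).softmax(dim=-1)

# ... equals multiplying the attention probabilities by the weights and renormalising
ref = logits.softmax(dim=-1) * weights
ref = ref / ref.sum()
assert torch.allclose(attn, ref, atol=1e-6)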

@@ -875,77 +1101,6 @@ class DownsampledSubformer2Encoder(nn.Module):
-class SimpleDownsample(torch.nn.Module):
-    """
-    Does downsampling with attention, by weighted sum, and a projection..
-    """
-    def __init__(self,
-                 channels: int,
-                 downsample: int,
-                 dropout: FloatLike):
-        super(SimpleDownsample, self).__init__()
-
-        self.bias = nn.Parameter(torch.zeros(downsample))
-
-        self.name = None  # will be set from training code
-        self.dropout = copy.deepcopy(dropout)
-
-        self.downsample = downsample
-
-    def forward(self,
-                src: Tensor) -> Tensor:
-        """
-        x: (seq_len, batch_size, in_channels)
-        Returns a tensor of shape
-           ( (seq_len+downsample-1)//downsample, batch_size, channels)
-        """
-        (seq_len, batch_size, in_channels) = src.shape
-        ds = self.downsample
-        d_seq_len = (seq_len + ds - 1) // ds
-
-        # Pad to an exact multiple of self.downsample
-        if seq_len != d_seq_len * ds:
-            # right-pad src, repeating the last element.
-            pad = d_seq_len * ds - seq_len
-            src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
-            src = torch.cat((src, src_extra), dim=0)
-            assert src.shape[0] == d_seq_len * ds
-
-        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
-
-        weights = self.bias.softmax(dim=0)
-        # weights: (downsample, 1, 1)
-        weights = weights.unsqueeze(-1).unsqueeze(-1)
-
-        # ans1 is the first `in_channels` channels of the output
-        ans = (src * weights).sum(dim=1)
-
-        return ans
-
-
-class SimpleUpsample(torch.nn.Module):
-    """
-    A very simple form of upsampling that mostly just repeats the input, but
-    also adds a position-specific bias.
-    """
-    def __init__(self,
-                 num_channels: int,
-                 upsample: int):
-        super(SimpleUpsample, self).__init__()
-        self.upsample = upsample
-
-    def forward(self,
-                src: Tensor) -> Tensor:
-        """
-        x: (seq_len, batch_size, num_channels)
-        Returns a tensor of shape
-           ( (seq_len*upsample), batch_size, num_channels)
-        """
-        upsample = self.upsample
-        (seq_len, batch_size, num_channels) = src.shape
-        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
-        src = src.reshape(seq_len * upsample, batch_size, num_channels)
-        return src

class CompactRelPositionalEncoding(torch.nn.Module):

@@ -967,17 +1122,19 @@ class CompactRelPositionalEncoding(torch.nn.Module):
    Args:
-       embed_dim: Embedding dimension.
+       embed_dim: Temporary embedding dimension used inside this module.
        dropout_rate: Dropout rate.
        max_len: Maximum input length: just a heuristic for initialization.
        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
           less weight to small differences of offset near the origin.
+       pos_dim: dimension at the output of this module.
    """
    def __init__(
        self, embed_dim: int,
        dropout_rate: FloatLike,
        max_len: int = 1000,
        length_factor: float = 1.0,
+       pos_dim: int = 4,
    ) -> None:
        """Construct a CompactRelPositionalEncoding object."""
        super(CompactRelPositionalEncoding, self).__init__()

@@ -989,6 +1146,11 @@ class CompactRelPositionalEncoding(torch.nn.Module):
        self.length_factor = length_factor
        self.extend_pe(torch.tensor(0.0).expand(max_len))

+       # linear transformation for positional encoding.
+       self.linear_pos = ScaledLinear(embed_dim,
+                                      pos_dim,
+                                      bias=False,
+                                      initial_scale=0.05)

    def extend_pe(self, x: Tensor) -> None:

@@ -1046,27 +1208,45 @@ class CompactRelPositionalEncoding(torch.nn.Module):
        """Create positional encoding.

        Args:
-           x (torch.Tensor): Input tensor (time, batch, `*`).
+           x (torch.Tensor): Input tensor (time, batch, num_channels_in)
        Returns:
-           positional embedding, of shape (1, 2*time-1, `*`).
+           positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim).
        """
        self.extend_pe(x)
+       seq_len = x.size(0)
        pos_emb = self.pe[
            self.pe.size(0) // 2
-           - x.size(0)
+           - seq_len
            + 1 : self.pe.size(0) // 2  # noqa E203
-           + x.size(0),
+           + seq_len,
            :
        ]
        pos_emb = pos_emb.unsqueeze(0)
-       return self.dropout(pos_emb)
+       pos_emb = self.dropout(pos_emb)
+       pos_emb = self.linear_pos(pos_emb)
+       # currently pos_emb: (1, 2*seq_len-1, pos_dim)
+       pos_dim = pos_emb.shape[-1]
+       batch_size = x.size(1)
+       (_, seq_stride, channel_stride) = pos_emb.stride()
+       # it doesn't really matter which one we make positive and which negative here, it
+       # would just flip the meaning of the embedding.
+       pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim),
+                                    (0, -seq_stride, seq_stride, channel_stride),
+                                    storage_offset=seq_stride * (seq_len - 1))
+       return pos_emb  # (batch_size, seq_len, seq_len, pos_dim)

class RelPositionMultiheadAttentionWeights(nn.Module):
-   r"""Module that computes multi-head attention weights with relative position encoding.
+   r"""Module that computes multi-head attention weights with relative position encoding;
+   in this version, the positions for each frame are passed in (in order to support
+   downsampling).

    Various other modules consume the resulting attention weights: see, for example, the
    SimpleAttention module which allows you to compute conventional attention.
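
For reference, a sketch (not the commit's code) of the layout that CompactRelPositionalEncoding.forward() above builds with as_strided: starting from one embedding per relative offset, it forms a (seq_len, seq_len, pos_dim) table whose (t, s) entry is the embedding for offset s - t, up to the sign convention that the code comments note is arbitrary. An equivalent indexing-based construction, with illustrative names:

import torch

seq_len, pos_dim = 5, 4
# one embedding vector per relative offset in [-(seq_len-1), seq_len-1]
rel = torch.randn(2 * seq_len - 1, pos_dim)

# out[t, s] is the embedding for offset (s - t); offset 0 sits at index seq_len - 1,
# matching the storage_offset of seq_stride * (seq_len - 1) in the code above
t = torch.arange(seq_len).unsqueeze(1)               # (seq_len, 1)
s = torch.arange(seq_len).unsqueeze(0)               # (1, seq_len)
out = rel[(s - t) + (seq_len - 1)]                   # (seq_len, seq_len, pos_dim)

assert out.shape == (seq_len, seq_len, pos_dim)
assert torch.equal(out[2, 2], rel[seq_len - 1])      # zero offset on the diagonal
assert torch.equal(out[0, 4], rel[2 * seq_len - 2])  # largest positive offset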

@@ -1076,10 +1256,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
    Args:
          embed_dim: number of channels at the input to this module, e.g. 256
-         pos_dim: dimension of the positional encoding vectors, e.g. 128.
          num_heads: number of heads to compute weights for, e.g. 8
          query_head_dim: dimension of the query (and key), per head.  e.g. 24.
-         pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
+         pos_dim: dimension of the projected positional encoding, e.g. 4.
          dropout: dropout probability for attn_output_weights. Default: 0.0.
          pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
             any given call to forward(), in training time.

@@ -1088,10 +1267,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
    def __init__(
            self,
            embed_dim: int,
-           pos_dim: int,
            num_heads: int,
            query_head_dim: int,
-           pos_head_dim: int,
+           pos_dim: int,
            dropout: float = 0.0,
            pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
                                                          (4000.0, 0.0))

@@ -1100,13 +1278,13 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.query_head_dim = query_head_dim
-       self.pos_head_dim = pos_head_dim
+       self.pos_dim = pos_dim
        self.dropout = dropout
        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
        self.name = None  # will be overwritten in training code; for diagnostics.

        key_head_dim = query_head_dim
-       in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
+       in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads

        # the initial_scale is supposed to take over the "scaling" factor of
        # head_dim ** -0.5 that has been used in previous forms of attention,

@@ -1138,13 +1316,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                               prob=0.025)

-       # linear transformation for positional encoding.
-       self.linear_pos = ScaledLinear(pos_dim,
-                                      num_heads * pos_head_dim,
-                                      bias=False,
-                                      initial_scale=0.05)

        # the following are for diagnostics only, see --print-diagnostics option
        self.copy_pos_query = Identity()
        self.copy_query = Identity()

@@ -1154,27 +1325,30 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            self,
            x: Tensor,
            pos_emb: Tensor,
-           chunk_size: int = -1,
-           key_padding_mask: Optional[Tensor] = None,
-           attn_mask: Optional[Tensor] = None,
+           attn_offset: Optional[Tensor] = None,
+           pos_emb: Tensor,
+           quadratic_pos_weight: Tensor,
    ) -> Tensor:
        r"""
        Args:
            x: input of shape (seq_len, batch_size, embed_dim)
            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim)
-           chunk_size
-           key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
-              are True in this mask will be ignored as sources in the attention weighting.
-           attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
-              interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
-              saying which positions are allowed to attend to which other positions.
+           attn_offset: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len),
+              interpreted as (batch_size, tgt_seq_len, src_seq_len), if provided this
+              contains values (probably <= 0) to be added to the logprobs of the attention;
+              this may combine the log of 'weights' of LearnedDownsamplingModule with
+              any attn_mask that enforces causality.
+           pos_emb: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len, pos_dim)
+              (e.g. pos_dim=4), encoding relative positions.
        Returns:
           a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
           interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
        """
        x = self.in_proj(x)
        query_head_dim = self.query_head_dim
-       pos_head_dim = self.pos_head_dim
+       pos_dim = self.pos_dim
        num_heads = self.num_heads

        seq_len, batch_size, _ = x.shape

@@ -1185,7 +1359,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        k = x[...,query_dim:2*query_dim]
        # p is the position-encoding query
        p = x[...,2*query_dim:]
-       assert p.shape[-1] == num_heads * pos_head_dim
+       assert p.shape[-1] == num_heads * pos_dim

        q = self.copy_query(q)  # for diagnostics only, does nothing.
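
Finally, a sketch (not from the commit) of the projection split that the assert above checks, using the constructor defaults shown earlier (query_head_dim=24, pos_dim=4, num_heads=8); the per-head reshapes at the end are illustrative only:

import torch

seq_len, batch_size = 7, 2
num_heads, query_head_dim, pos_dim = 8, 24, 4
key_head_dim = query_head_dim
in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads

x = torch.randn(seq_len, batch_size, in_proj_dim)   # stands in for the output of self.in_proj

query_dim = query_head_dim * num_heads
q = x[..., 0:query_dim]                 # (seq_len, batch_size, num_heads * query_head_dim)
k = x[..., query_dim:2 * query_dim]     # same shape as q
p = x[..., 2 * query_dim:]              # position-encoding query
assert p.shape[-1] == num_heads * pos_dim

# per-head views, e.g. for building (num_heads, batch_size, tgt_seq_len, src_seq_len) scores
q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
p = p.reshape(seq_len, batch_size, num_heads, pos_dim)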