mirror of https://github.com/k2-fsa/icefall.git
More progress on subformer
This commit is contained in:
parent
5c470fe397
commit
f740282a1a
@@ -93,12 +93,12 @@ class Subformer2(EncoderInterface):
num_encoder_layers: Union[int, Tuple[int]] = 4,
encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
query_head_dim: Union[int, Tuple[int]] = 24,
pos_head_dim: Union[int, Tuple[int]] = 4,
value_head_dim: Union[int, Tuple[int]] = 12,
num_heads: Union[int, Tuple[int]] = 8,
feedforward_dim: Union[int, Tuple[int]] = 1536,
memory_dim: int = -1,
pos_dim: int = 192,
pos_emb_dim: int = 192,
pos_dim: int = 4,
dropout: FloatLike = None,  # see code below for default
warmup_batches: float = 4000.0,
causal: bool = False,
@@ -127,7 +127,6 @@ class Subformer2(EncoderInterface):
num_encoder_layers = _to_tuple(num_encoder_layers)
query_head_dim = _to_tuple(query_head_dim)
value_head_dim = _to_tuple(value_head_dim)
pos_head_dim = _to_tuple(pos_head_dim)
num_heads = _to_tuple(num_heads)
feedforward_dim = _to_tuple(feedforward_dim)

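The constructor normalizes these per-stack options so that a single int can be given for all encoder stacks. A minimal sketch of what a _to_tuple helper of this kind typically does (the signature here, with an explicit num_stacks argument, is illustrative rather than the actual one in subformer.py):

from typing import Tuple, Union

def _to_tuple(x: Union[int, Tuple[int, ...]], num_stacks: int = 6) -> Tuple[int, ...]:
    # Broadcast a scalar to one value per encoder stack; leave tuples/lists alone
    # (their length is assumed to already match num_stacks).
    if isinstance(x, (tuple, list)):
        assert len(x) == num_stacks
        return tuple(x)
    return (x,) * num_stacks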
@@ -145,7 +144,6 @@ class Subformer2(EncoderInterface):
pos_dim=pos_dim,
num_heads=num_heads[i],
query_head_dim=query_head_dim[i],
pos_head_dim=pos_head_dim[i],
value_head_dim=value_head_dim[i],
feedforward_dim=feedforward_dim[i],
memory_dim=memory_dim,
@@ -175,6 +173,9 @@ class Subformer2(EncoderInterface):

encoders.append(encoder)

self.encoder_pos = CompactRelPositionalEncoding(pos_emb_dim, pos_dim, dropout_rate=0.15,
length_factor=1.0)

self.encoders = nn.ModuleList(encoders)

self.downsample_output = SimpleDownsample(max(encoder_dim),
@@ -272,7 +273,7 @@ class Subformer2(EncoderInterface):
outputs = []
feature_masks = self.get_feature_masks(x)

attn_mask = self._get_attn_mask(x)
attn_offset = self._get_attn_offset(x)

if self.training and memory is not None:
batch_size = x.shape[1]
@@ -282,16 +283,19 @@ class Subformer2(EncoderInterface):
memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
memory_dropout_rate)

pos_emb = self.encoder_pos(x)

for i, module in enumerate(self.encoders):
ds = self.downsampling_factor[i]
x = convert_num_channels(x, self.encoder_dim[i])

x = module(x,
pos_emb,
chunk_size=chunk_size,
feature_mask=feature_masks[i],
src_key_padding_mask=(None if src_key_padding_mask is None
else src_key_padding_mask[...,::ds]),
attn_mask=attn_mask,
attn_offset=attn_offset,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
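Each stack can have its own embedding dimension, so convert_num_channels adapts x when moving between stacks. The usual zipformer-style behaviour (assumed here; the real helper lives elsewhere in the recipe) is to slice when shrinking and zero-pad when growing:

import torch
from torch import Tensor

def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
    # x: (seq_len, batch_size, cur_channels); returns (seq_len, batch_size, num_channels).
    cur = x.shape[-1]
    if num_channels <= cur:
        return x[..., :num_channels]          # drop extra channels
    pad = torch.zeros(*x.shape[:-1], num_channels - cur,
                      dtype=x.dtype, device=x.device)
    return torch.cat((x, pad), dim=-1)        # zero-pad the new channels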
@@ -324,11 +328,11 @@ class Subformer2(EncoderInterface):

return x, lengths

def _get_attn_mask(self, x: Tensor) -> Optional[Tensor]:
def _get_attn_offset(self, x: Tensor) -> Optional[Tensor]:
"""
Return None if self.causal is False; else return an attention mask of shape
(seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True
means a masked position.
Return attention offset of shape (1, seq_len, seq_len), interpreted as (tgt_seq_len,
src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros.

Args:
x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
chunk_size: chunk size, must divide
@@ -345,7 +349,11 @@ class Subformer2(EncoderInterface):

attn_mask = (src_c > tgt_c)

return attn_mask
ans = torch.zeros(1, seq_len, seq_len, device=x.device)

ans.masked_fill_(attn_mask, float('-inf'))

return ans

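The change above replaces a boolean attn_mask with an additive offset: 0 for allowed positions and -inf for masked ones. A small self-contained sketch of a causal offset of this kind (names chosen for illustration only):

import torch

def causal_attn_offset(seq_len: int, device: str = "cpu") -> torch.Tensor:
    # (1, seq_len, seq_len); entry [0, t, s] is -inf where source s is in the future of target t.
    t = torch.arange(seq_len, device=device)
    future = t.unsqueeze(0) > t.unsqueeze(1)      # (tgt, src): src > tgt means masked
    ans = torch.zeros(1, seq_len, seq_len, device=device)
    return ans.masked_fill_(future.unsqueeze(0), float("-inf"))

# e.g. causal_attn_offset(4)[0] has zeros on and below the diagonal, -inf above it.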
@@ -379,7 +387,7 @@ class Subformer2EncoderLayer(nn.Module):
pos_dim: int,
num_heads: int,
query_head_dim: int,
pos_head_dim: int,
pos_dim: int,
value_head_dim: int,
feedforward_dim: int,
dropout: FloatLike = 0.1,
@@ -416,8 +424,8 @@ class Subformer2EncoderLayer(nn.Module):
self.const_attention_rate = copy.deepcopy(const_attention_rate)

self.self_attn_weights = RelPositionMultiheadAttentionWeights(
embed_dim, pos_dim=pos_dim, num_heads=num_heads,
query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
embed_dim, num_heads=num_heads,
query_head_dim=query_head_dim, pos_dim=pos_dim,
dropout=0.0,
)

@@ -552,7 +560,7 @@ class Subformer2EncoderLayer(nn.Module):
src: Tensor,
pos_emb: Tensor,
chunk_size: int = -1,
attn_mask: Optional[Tensor] = None,
attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
@@ -565,10 +573,9 @@ class Subformer2EncoderLayer(nn.Module):
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
True means masked position. May be None.
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len). -inf for masked position.
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.

Returns:
@@ -583,7 +590,7 @@ class Subformer2EncoderLayer(nn.Module):
attn_weights = self.self_attn_weights(
src,
pos_emb=pos_emb,
attn_mask=attn_mask,
attn_offset=attn_offset,
key_padding_mask=src_key_padding_mask,
)

@@ -675,9 +682,6 @@ class Subformer2Encoder(nn.Module):
final_layerdrop_rate: float = 0.05,
) -> None:
super().__init__()
self.encoder_pos = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.15,
length_factor=1.0)

self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
@@ -699,7 +703,7 @@ class Subformer2Encoder(nn.Module):
src: Tensor,
chunk_size: int = -1,
feature_mask: Union[Tensor, float] = 1.0,
attn_mask: Optional[Tensor] = None,
attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
@@ -711,9 +715,9 @@ class Subformer2Encoder(nn.Module):
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
True means masked position. May be None.
attn_offset: the attention offset (does masking and related tasks), of shape
broadcasting with (batch_size, seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len).
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
@@ -735,7 +739,7 @@ class Subformer2Encoder(nn.Module):
output,
pos_emb,
chunk_size=chunk_size,
attn_mask=attn_mask,
attn_offset=attn_offset,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
@@ -803,10 +807,227 @@ class BypassModule(nn.Module):
return src_orig + (src - src_orig) * bypass_scale


class LearnedDownsamplingModule(nn.Module):
"""
Module that allows you to choose which frames to keep for transformer-type
modules. Effectively downsampling, but not necessarily "evenly"- you just
keep some proportion of frames determined by the embedding.

Args:
embed_dim: embedding dimension
downsampling_factor: factor to downsample by, e.g. 2 or 4. There is no
fundamental reason why this has to be an integer, but we make it so
anyway.
intermediate_rate: the proportion of the downsampled values that have
"intermediate weights"- between kept and downsampled. The user is
supposed to use these in such a way that if the weight we return is
0.0, it's equivalent to not using this frame at all.
"""
def __init__(self,
embed_dim: int,
downsampling_factor: int,
intermediate_rate: FloatLike = 0.2):
super().__init__()
self.to_scores = nn.Linear(embed_dim, 1, bias=False)
# score_balancer is just to keep the magnitudes of the scores in
# a fixed range and keep them balanced around zero, to stop
# these drifting around.
self.score_balancer = Balancer(1, channel_dim=-1,
min_positive=0.4, max_positive=0.6,
min_abs=1.0, max_abs=1.2,
prob=0.025)

self.downsampling_factor = downsampling_factor
self.intermediate_rate = copy.deepcopy(intermediate_rate)

def forward(self,
x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""
Args:
x: a Tensor of shape (seq_len, batch_size, embed_dim)

Returns: (frame_indexes, weights, x_downsampled)

frame_indexes: a Tensor of integer type, of shape (batch_size, reduced_seq_len)
where reduced_seq_len = (seq_len + d - 1) // d. It contains elements
0 <= frame_indexes < seq_len, in sorted (increasing) order

weights: a Tensor of shape (batch_size, reduced_seq_len),
corresponding to the kept frames; these will be between 0 and 1, but
mostly exactly 1.

x_downsampled: the selected frames of x, of shape (reduced_seq_len, batch_size, embed_dim)
"""
(seq_len, batch_size, _) = x.shape
scores = self.to_scores(x)  # (seq_len, batch_size, 1)
scores = self.score_balancer(scores)

scores = scores.squeeze(-1).t()  # (batch_size, seq_len)

# indexes, sscores: (batch_size, seq_len)
indexes, sscores = scores.sort(dim=-1, descending=True)

d = self.downsampling_factor
seq_len_reduced = (seq_len + d - 1) // d

# TODO: if seq_len / downsampling_factor <= 2, do something special.

# 'right' is the rightmost of the 2 limits; we want the scores indexed
# 'right' to be mapped to around 0.0
right = seq_len_reduced
# we want scores around 'left' to be mapped to around 1.0.
left = int(seq_len_reduced * (1.0 - self.intermediate_rate))

# 'collar' determines the range of positions in the sorted list that we use to
# compute the average. We could let collar be 0.0, which would more exactly
# accomplish what we want; but we don't, because this would cause too-noisy
# gradients, with too much gradient going to one frame.
collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate))

# right_avg: shape (batch_size, 1), this is to be mapped to 0.0
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)

# left_avg: shape (batch_size, 1), this is to be mapped to 1.0
left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)

# the + 0.001 is to avoid possible division by zero in case of ties.
weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
weights = weights.clamp(min=0.0, max=1.0)

indexes = indexes[:, :seq_len_reduced]
weights = weights[:, :seq_len_reduced]

# re-sort the indexes we kept, on index value, so that
# masking for causal models will be in the correct order.

indexes, reorder = indexes.sort(dim=-1)
weights = torch.gather(weights, dim=-1, index=reorder)

x_downsampled = downsample(x, indexes)
return indexes, weights, x_downsampled

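The weights are obtained by sorting the per-frame scores and mapping the score at position `left` to about 1.0 and the score at the cut-off position `right` to about 0.0, clamped to [0, 1]. A small standalone illustration of that mapping (values and names are illustrative only):

import torch

torch.manual_seed(0)
scores = torch.randn(1, 16)                      # (batch=1, seq_len=16)
indexes, sscores = scores.sort(dim=-1, descending=True)

seq_len_reduced, intermediate_rate = 8, 0.2
right = seq_len_reduced
left = int(seq_len_reduced * (1.0 - intermediate_rate))
collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))

right_avg = sscores[:, right - collar:right + collar + 1].mean(dim=-1, keepdim=True)
left_avg = sscores[:, left - collar:left + collar + 1].mean(dim=-1, keepdim=True)

weights = ((sscores - right_avg) / (left_avg - right_avg + 0.001)).clamp(0.0, 1.0)
print(weights[:, :seq_len_reduced])   # mostly 1.0, tapering toward 0 near the cut-off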
def downsample(x: Tensor, indexes: Tensor) -> Tensor:
"""
Downsamples x via indexing with the indexes obtained from the
forward() function.

Args:
x: tensor of shape (seq_len, batch_size, num_channels)
indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements
0 <= indexes < seq_len.
Returns:
x_downsampled, of shape (seq_len_reduced, batch_size, num_channels)
"""
indexes = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
# indexes now: (seq_len_reduced, batch_size, num_channels)
ans = torch.gather(x, dim=0, index=indexes)

if __name__ == '__main__':
# temp, for testing
x_reconstructed = upsample(x, ans, indexes)
assert torch.allclose(x, x_reconstructed)

return ans

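For concreteness, here is a tiny usage sketch of this gather-based frame selection, with values chosen only to make the shapes visible (not part of the recipe):

import torch

seq_len, batch_size, num_channels = 6, 2, 3
x = torch.arange(seq_len * batch_size * num_channels, dtype=torch.float32)
x = x.reshape(seq_len, batch_size, num_channels)

# keep frames 0, 2, 5 for the first sequence and 1, 3, 4 for the second
indexes = torch.tensor([[0, 2, 5], [1, 3, 4]])          # (batch_size, seq_len_reduced)

idx = indexes.t().unsqueeze(-1).expand(-1, -1, num_channels)
x_ds = torch.gather(x, dim=0, index=idx)                # (3, 2, 3)
assert torch.equal(x_ds[0, 0], x[0, 0]) and torch.equal(x_ds[0, 1], x[1, 1])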
def downsample_pos_emb(pos_emb: Tensor, indexes: Tensor) -> Tensor:
"""
Downsample positional embedding tensor with the provided indexes.
Args:
pos_emb: (batch_size, seq_len, seq_len, pos_dim)
interpreted as (batch_size, tgt_seq_len, src_seq_len, pos_dim).
indexes: (batch_size, seq_len_reduced), containing integer elements
0 <= indexes < seq_len.
Returns:
downsampled pos_emb, of shape (batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
"""

(batch_size, seq_len_reduced) = indexes.shape
(_, _, seq_len, pos_dim) = pos_emb.shape

tgt_indexes = indexes.reshape(batch_size, seq_len_reduced, 1, 1).expand(
batch_size, seq_len_reduced, seq_len, pos_dim)

pos_emb = torch.gather(pos_emb, dim=1, index=tgt_indexes)
# now pos_emb: (batch_size, seq_len_reduced, seq_len, pos_dim)

src_indexes = indexes.reshape(batch_size, 1, seq_len_reduced, 1).expand(
batch_size, seq_len_reduced, seq_len_reduced, pos_dim)

pos_emb = torch.gather(pos_emb, dim=2, index=src_indexes)
# now pos_emb: (batch_size, seq_len_reduced, seq_len_reduced, pos_dim)
return pos_emb

def downsample_attn_offset(attn_offset: Tensor,
indexes: Tensor,
weights: Tensor,
eps: float = 1.0e-05) -> Tensor:
"""
Downsamples attn_offset and also modifies it to account for the weights in `weights`.
Args:
attn_offset: a Tensor of shape (1 or batch_size, seq_len, seq_len), interpreted as
(1 or batch_size, tgt_seq_len, src_seq_len)
indexes: a Tensor of shape (batch_size, reduced_seq_len) containing elements
0 <= indexes < seq_len.
weights: a Tensor of shape (batch_size, reduced_seq_len) containing weights
between 0 and 1; most will be 1.
Returns:
attn_offset_downsampled, a Tensor of shape (batch_size, reduced_seq_len, reduced_seq_len)
"""
(batch_size, seq_len_reduced) = indexes.shape
seq_len = attn_offset.shape[-1]
assert len(attn_offset.shape) == 3  # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)

attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
batch_size, seq_len_reduced, seq_len))
attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
batch_size, seq_len_reduced, seq_len_reduced))
# unsqueeze at position 1 so the extra cost relates to the source position.
attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)

return attn_offset

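The log of the keep-weights is added to the already-downsampled offset, so a frame kept with weight w contributes log(w) <= 0 to every score that attends to it, and a weight near zero effectively masks the frame out. A minimal standalone sketch of that combination (names illustrative):

import torch

batch_size, seq_len_reduced = 2, 5
base_offset = torch.zeros(batch_size, seq_len_reduced, seq_len_reduced)  # e.g. a causal -inf pattern could go here
weights = torch.tensor([[1.0, 1.0, 0.3, 1.0, 0.0],
                        [1.0, 0.7, 1.0, 1.0, 1.0]])

eps = 1.0e-05
# unsqueeze(1): the penalty depends on the *source* frame being attended to.
offset = base_offset + weights.clamp(min=eps).log().unsqueeze(1)
# a weight of 1.0 adds 0, smaller weights add increasingly negative scores,
# and a weight of ~0 makes that source frame effectively invisible.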
def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
"""
Upsamples, reversing the downsample() operation and filling in
any not-chosen frames with their original value before downsampling
(or with whatever x_orig contains).

Args:
x_orig: (seq_len, batch_size, num_channels)
x: (seq_len_reduced, batch_size, num_channels)
indexes: (batch_size, seq_len_reduced), contains original frame indexes

Returns:
a Tensor of shape (seq_len, batch_size, num_channels)
"""
(seq_len, batch_size, num_channels) = x_orig.shape

not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
device=x.device)
not_kept.scatter_(dim=1, index=indexes, value=False)

indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
# indexes now: (seq_len_reduced, batch_size, num_channels)

ans = torch.zeros_like(x_orig)

ans.scatter_(dim=0, index=indexes, src=x)

# add in x_orig in the frames that were not originally kept.
return ans + x_orig * not_kept.t().unsqueeze(-1)

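Together, downsample() and upsample() form a select/scatter round trip: kept frames come back from x, everything else falls back to x_orig. A short check of that property, assuming the two functions above:

import torch

seq_len, batch_size, num_channels = 8, 2, 4
x_orig = torch.randn(seq_len, batch_size, num_channels)
indexes = torch.stack([torch.randperm(seq_len)[:4].sort().values
                       for _ in range(batch_size)])          # (batch_size, 4), sorted

x_ds = downsample(x_orig, indexes)       # (4, batch_size, num_channels)
x_up = upsample(x_orig, x_ds, indexes)   # (seq_len, batch_size, num_channels)

# kept frames are reproduced exactly, and untouched frames fall back to x_orig,
# so the round trip is the identity here.
assert torch.allclose(x_up, x_orig)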
class DownsampledSubformer2Encoder(nn.Module):
r"""
DownsampledSubformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
after convolutional downsampling, and then upsampled again at the output, and combined
with the original input, so that the output has the same shape as the input.
@@ -818,18 +1039,18 @@ class DownsampledSubformer2Encoder(nn.Module):
dropout: FloatLike):
super(DownsampledSubformer2Encoder, self).__init__()
self.downsample_factor = downsample
self.downsample = SimpleDownsample(dim,
downsample, dropout)
self.downsampler = LearnedDownsamplingModule(dim,
downsample)
self.encoder = encoder
self.upsample = SimpleUpsample(dim, downsample)

self.out_combiner = BypassModule(dim, straight_through_rate=0.025)


def forward(self,
src: Tensor,
chunk_size: int = -1,
pos_emb: Tensor,
feature_mask: Union[Tensor, float] = 1.0,
attn_mask: Optional[Tensor] = None,
attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
@@ -838,11 +1059,12 @@ class DownsampledSubformer2Encoder(nn.Module):

Args:
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
attn_offset: the attention offset, added to scores for attention of shape
(batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
True means masked position. May be None.
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
@@ -853,16 +1075,20 @@ class DownsampledSubformer2Encoder(nn.Module):
Returns: a Tensor with the same shape as src.
"""
src_orig = src
src = self.downsample(src)
ds = self.downsample_factor
if attn_mask is not None:
attn_mask = attn_mask[::ds,::ds]
indexes, weights, src = self.downsampler(src)

pos_emb = downsample_pos_emb(pos_emb, indexes)

attn_offset = downsample_attn_offset(attn_offset,
indexes,
weights.clamp(min=1.0e-05))


src = self.encoder(
src,
chunk_size=chunk_size // ds,
pos_emb,
feature_mask=feature_mask,
attn_mask=attn_mask,
attn_offset=attn_offset,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
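Putting the pieces together, the learned-downsampling wrapper scores and selects frames, shrinks the positional embeddings and attention offsets to the kept frames, runs the inner encoder at the reduced rate, then scatters back to the full rate and combines with the input. A schematic sketch of that control flow (the final bypass combination via out_combiner is assumed here, not shown in this hunk):

# Schematic only: shapes are src (seq_len, batch, dim), pos_emb (batch, seq_len, seq_len, pos_dim).
def downsampled_encoder_forward(downsampler, encoder, out_combiner,
                                src, pos_emb, attn_offset, **kwargs):
    src_orig = src
    indexes, weights, src = downsampler(src)                  # keep roughly 1/downsample of the frames
    pos_emb = downsample_pos_emb(pos_emb, indexes)            # (batch, reduced, reduced, pos_dim)
    attn_offset = downsample_attn_offset(attn_offset, indexes,
                                         weights.clamp(min=1.0e-05))
    src = encoder(src, pos_emb, attn_offset=attn_offset, **kwargs)
    src = upsample(src_orig, src, indexes)                    # back to the full seq_len
    return out_combiner(src_orig, src)                        # learned bypass / combination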
@@ -875,77 +1101,6 @@ class DownsampledSubformer2Encoder(nn.Module):


class SimpleDownsample(torch.nn.Module):
"""
Does downsampling with attention, by weighted sum, and a projection.
"""
def __init__(self,
channels: int,
downsample: int,
dropout: FloatLike):
super(SimpleDownsample, self).__init__()

self.bias = nn.Parameter(torch.zeros(downsample))

self.name = None  # will be set from training code
self.dropout = copy.deepcopy(dropout)

self.downsample = downsample

def forward(self,
src: Tensor) -> Tensor:
"""
x: (seq_len, batch_size, in_channels)
Returns a tensor of shape
( (seq_len+downsample-1)//downsample, batch_size, channels)
"""
(seq_len, batch_size, in_channels) = src.shape
ds = self.downsample
d_seq_len = (seq_len + ds - 1) // ds

# Pad to an exact multiple of self.downsample
if seq_len != d_seq_len * ds:
# right-pad src, repeating the last element.
pad = d_seq_len * ds - seq_len
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds

src = src.reshape(d_seq_len, ds, batch_size, in_channels)

weights = self.bias.softmax(dim=0)
# weights: (downsample, 1, 1)
weights = weights.unsqueeze(-1).unsqueeze(-1)

# ans1 is the first `in_channels` channels of the output
ans = (src * weights).sum(dim=1)

return ans


class SimpleUpsample(torch.nn.Module):
"""
A very simple form of upsampling that mostly just repeats the input, but
also adds a position-specific bias.
"""
def __init__(self,
num_channels: int,
upsample: int):
super(SimpleUpsample, self).__init__()
self.upsample = upsample

def forward(self,
src: Tensor) -> Tensor:
"""
x: (seq_len, batch_size, num_channels)
Returns a tensor of shape
( (seq_len*upsample), batch_size, num_channels)
"""
upsample = self.upsample
(seq_len, batch_size, num_channels) = src.shape
src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
src = src.reshape(seq_len * upsample, batch_size, num_channels)
return src

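SimpleDownsample averages each group of `downsample` frames with learned softmax weights, and SimpleUpsample simply repeats frames; a quick shape check assuming the two classes above (passing a plain float for the FloatLike dropout):

import torch

ds = SimpleDownsample(channels=4, downsample=2, dropout=0.0)
us = SimpleUpsample(num_channels=4, upsample=2)

src = torch.randn(7, 3, 4)             # (seq_len=7, batch=3, channels=4)
y = ds(src)                            # (ceil(7/2)=4, 3, 4): weighted mean of frame pairs
z = us(y)                              # (8, 3, 4): each reduced frame repeated twice
print(y.shape, z.shape)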
class CompactRelPositionalEncoding(torch.nn.Module):
@@ -967,17 +1122,19 @@ class CompactRelPositionalEncoding(torch.nn.Module):


Args:
embed_dim: Embedding dimension.
embed_dim: Temporary embedding dimension used inside this module
dropout_rate: Dropout rate.
max_len: Maximum input length: just a heuristic for initialization.
length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
less weight to small differences of offset near the origin.
pos_dim: dimension at the output of this module.
"""
def __init__(
self, embed_dim: int,
self, embed_dim: int,
dropout_rate: FloatLike,
max_len: int = 1000,
length_factor: float = 1.0,
pos_dim: int = 4,
) -> None:
"""Construct a CompactRelPositionalEncoding object."""
super(CompactRelPositionalEncoding, self).__init__()
@@ -989,6 +1146,11 @@ class CompactRelPositionalEncoding(torch.nn.Module):
self.length_factor = length_factor
self.extend_pe(torch.tensor(0.0).expand(max_len))

# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim,
pos_dim,
bias=False,
initial_scale=0.05)


def extend_pe(self, x: Tensor) -> None:
@@ -1046,27 +1208,45 @@ class CompactRelPositionalEncoding(torch.nn.Module):
"""Create positional encoding.

Args:
x (torch.Tensor): Input tensor (time, batch, `*`).
x (torch.Tensor): Input tensor (time, batch, num_channels_in)

Returns:
positional embedding, of shape (1, 2*time-1, `*`).
positional embedding, of shape (1, 2*time-1, pos_dim).

"""
self.extend_pe(x)
seq_len = x.size(0)
pos_emb = self.pe[
self.pe.size(0) // 2
- x.size(0)
- seq_len,
+ 1 : self.pe.size(0) // 2  # noqa E203
+ x.size(0),
+ seq_len,
:
]
pos_emb = pos_emb.unsqueeze(0)
return self.dropout(pos_emb)
pos_emb = self.dropout(pos_emb)
pos_emb = self.linear_pos(pos_emb)

# currently pos_emb: (1, 2*seq_len-1, pos_dim)
pos_dim = pos_emb.shape[-1]
batch_size = x.size(1)
(_, seq_stride, channel_stride) = pos_emb.stride()
# it doesn't really matter which one we make positive and which negative here, it
# would just flip the meaning of the embedding.
pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim),
(0, -seq_stride, seq_stride, channel_stride),
storage_offset=seq_stride * (seq_len - 1))

return pos_emb  # (batch_size, seq_len, seq_len, pos_dim)



class RelPositionMultiheadAttentionWeights(nn.Module):
r"""Module that computes multi-head attention weights with relative position encoding.
r"""Module that computes multi-head attention weights with relative position encoding;
in this version, the positions for each frame are passed in (in order to support

Various other modules consume the resulting attention weights: see, for example, the
SimpleAttention module which allows you to compute conventional attention.

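The as_strided call above expands the (1, 2*seq_len-1, pos_dim) relative-position table into a (batch_size, seq_len, seq_len, pos_dim) tensor whose [t, s] entry encodes the offset s - t. An equivalent way to build the same lookup with plain indexing, shown only to make the intent explicit (illustrative code, not the module's implementation):

import torch

seq_len, pos_dim, batch_size = 5, 4, 2
pos_table = torch.randn(1, 2 * seq_len - 1, pos_dim)     # index k encodes offset k - (seq_len - 1)

t = torch.arange(seq_len)
# rel_idx[t, s] = (s - t) + (seq_len - 1), i.e. the offset shifted to be non-negative
rel_idx = t.unsqueeze(0) - t.unsqueeze(1) + (seq_len - 1)    # (seq_len, seq_len)

rel_pos = pos_table[0][rel_idx]                          # (seq_len, seq_len, pos_dim)
rel_pos = rel_pos.unsqueeze(0).expand(batch_size, -1, -1, -1)
# rel_pos[b, t, s] == pos_table[0, s - t + seq_len - 1], the same mapping the strided view aims at.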
@@ -1076,22 +1256,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

Args:
embed_dim: number of channels at the input to this module, e.g. 256
pos_dim: dimension of the positional encoding vectors, e.g. 128.
num_heads: number of heads to compute weights for, e.g. 8
query_head_dim: dimension of the query (and key), per head. e.g. 24.
pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
pos_dim: dimension of the projected positional encoding, e.g. 4.
dropout: dropout probability for attn_output_weights. Default: 0.0.
pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
any given call to forward(), in training time.
pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
any given call to forward(), in training time.
"""

def __init__(
self,
embed_dim: int,
pos_dim: int,
num_heads: int,
query_head_dim: int,
pos_head_dim: int,
pos_dim: int,
dropout: float = 0.0,
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
(4000.0, 0.0))
@@ -1100,13 +1278,13 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.query_head_dim = query_head_dim
self.pos_head_dim = pos_head_dim
self.pos_dim = pos_dim
self.dropout = dropout
self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
self.name = None  # will be overwritten in training code; for diagnostics.

key_head_dim = query_head_dim
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads

# the initial_scale is supposed to take over the "scaling" factor of
# head_dim ** -0.5 that has been used in previous forms of attention,
@@ -1138,13 +1316,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
prob=0.025)


# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(pos_dim,
num_heads * pos_head_dim,
bias=False,
initial_scale=0.05)


# the following are for diagnostics only, see --print-diagnostics option
self.copy_pos_query = Identity()
self.copy_query = Identity()
@@ -1154,27 +1325,30 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self,
x: Tensor,
pos_emb: Tensor,
chunk_size: int = -1,
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
attn_offset: Optional[Tensor] = None,
pos_emb: Tensor,
quadratic_pos_weight: Tensor,
) -> Tensor:
r"""
Args:
x: input of shape (seq_len, batch_size, embed_dim)
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim)
chunk_size
key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
are True in this mask will be ignored as sources in the attention weighting.
attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
saying which positions are allowed to attend to which other positions.

attn_offset: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len), if provided this
contains values (probably <= 0) to be added to the logprobs of the attention;
this may combine the log of 'weights' of ChooseDownsamplingModule with
any attn_mask that enforces causality.
pos_emb: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len, pos_dim)
(e.g. pos_dim=4), encoding relative positions.

Returns:
a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
"""
x = self.in_proj(x)
query_head_dim = self.query_head_dim
pos_head_dim = self.pos_head_dim
pos_dim = self.pos_dim
num_heads = self.num_heads

seq_len, batch_size, _ = x.shape
@@ -1185,7 +1359,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
k = x[...,query_dim:2*query_dim]
# p is the position-encoding query
p = x[...,2*query_dim:]
assert p.shape[-1] == num_heads * pos_head_dim
assert p.shape[-1] == num_heads * pos_dim


q = self.copy_query(q)  # for diagnostics only, does nothing.