Do block-wise attention when seq_len is larger than 512, with block_size <= 512

yaozengwei 2023-07-23 16:12:57 +08:00
parent ee485c02fc
commit 215541c7c5
2 changed files with 318 additions and 47 deletions
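For a sense of what this change buys (a rough, illustrative sketch; the figures below are not taken from the commit): full self-attention materializes weights of shape (num_heads, batch_size, seq_len, seq_len), so memory grows quadratically with seq_len, while the block-wise path materializes (num_heads, batch_size * num_blocks, block_size, block_size + 2 * block_pad), which grows roughly linearly once block_size is capped at 512.

    # Illustrative attention-weight element counts. Assumes num_heads=4, batch=1,
    # and block_pad=16 (the default added in this commit); not code from the repo.
    import math

    def full_elems(seq_len, num_heads=4, batch=1):
        return num_heads * batch * seq_len * seq_len

    def block_elems(seq_len, max_block_size=512, block_pad=16, num_heads=4, batch=1):
        num_blocks = math.ceil(seq_len / max_block_size)
        block_size = math.ceil(seq_len / num_blocks)
        return num_heads * batch * num_blocks * block_size * (block_size + 2 * block_pad)

    for seq_len in (1024, 2048, 4096):
        print(seq_len, full_elems(seq_len), block_elems(seq_len))
    # At 4096 frames this is roughly 67M elements for full attention
    # versus roughly 8.9M for the block-wise layout.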

View File

@@ -188,10 +188,10 @@ def add_model_arguments(parser: argparse.ArgumentParser):
    )
    parser.add_argument(
-        "--block-size",
+        "--max-block-size",
        type=str,
-        default="32",
-        help="Block size used in block-wise attention; a single int or comma-separated list",
+        default="512",
+        help="Max block size used in block-wise attention; a single int or comma-separated list",
    )
    parser.add_argument(
@@ -581,7 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
        num_heads=_to_int_tuple(params.num_heads),
        feedforward_dim=_to_int_tuple(params.feedforward_dim),
        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
-        block_size=_to_int_tuple(params.block_size),
+        max_block_size=_to_int_tuple(params.max_block_size),
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
        causal=params.causal,

View File

@@ -106,7 +106,8 @@ class Zipformer2(EncoderInterface):
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
-        block_size: Union[int, Tuple[int]] = 32,
+        max_block_size: Union[int, Tuple[int]] = 512,
+        block_pad: int = 16,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
        causal: bool = False,
@@ -142,7 +143,7 @@ class Zipformer2(EncoderInterface):
        self.num_heads = num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
-        self.block_size = block_size = _to_tuple(block_size)
+        self.max_block_size = max_block_size = _to_tuple(max_block_size)
        self.causal = causal
        self.chunk_size = chunk_size
@@ -168,7 +169,6 @@ class Zipformer2(EncoderInterface):
                feedforward_dim=feedforward_dim[i],
                dropout=dropout,
                cnn_module_kernel=cnn_module_kernel[i],
-                block_size=block_size[i],
                causal=causal,
            )
@@ -178,7 +178,8 @@ class Zipformer2(EncoderInterface):
                encoder_layer,
                num_encoder_layers[i],
                pos_dim=pos_dim,
-                block_size=block_size[i],
+                max_block_size=max_block_size[i],
+                block_pad=block_pad,
                dropout=dropout,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
@@ -542,7 +543,6 @@ class Zipformer2EncoderLayer(nn.Module):
        feedforward_dim: int,
        dropout: FloatLike = 0.1,
        cnn_module_kernel: int = 31,
-        block_size: int = 32,
        causal: bool = False,
        attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
@@ -576,14 +576,14 @@ class Zipformer2EncoderLayer(nn.Module):
        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
            embed_dim, pos_dim=pos_dim, num_heads=num_heads,
            query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
-            block_size=block_size, dropout=0.0,
+            dropout=0.0,
        )
        self.self_attn1 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim, block_size=block_size)
+                                        value_head_dim)
        self.self_attn2 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim, block_size=block_size)
+                                        value_head_dim)
        self.feed_forward1 = FeedforwardModule(embed_dim,
                                               (feedforward_dim * 3) // 4,
@@ -598,8 +598,7 @@ class Zipformer2EncoderLayer(nn.Module):
                                               dropout)
        self.nonlin_attention = NonlinAttention(embed_dim,
-                                                hidden_channels=3 * embed_dim // 4,
-                                                block_size=block_size)
+                                                hidden_channels=3 * embed_dim // 4)
        self.conv_module1 = ConvolutionModule(embed_dim,
                                              cnn_module_kernel,
@@ -680,6 +679,8 @@ class Zipformer2EncoderLayer(nn.Module):
        self,
        src: Tensor,
        pos_emb: Tensor,
+        block_size: int = 0,
+        block_pad: int = 16,
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
@@ -689,6 +690,8 @@ class Zipformer2EncoderLayer(nn.Module):
        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+            block_size: size of block
+            block_pad: pad size at each side of block
            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
@@ -710,12 +713,21 @@ class Zipformer2EncoderLayer(nn.Module):
        attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights = self.self_attn_weights(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=attn_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
+        if block_size == 0:
+            attn_weights = self.self_attn_weights(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=attn_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
+        else:
+            attn_weights = self.self_attn_weights.forward_block(
+                src,
+                pos_emb=pos_emb,
+                block_size=block_size,
+                block_pad=block_pad,
+                key_padding_mask=src_key_padding_mask,
+            )

        src = src + self.feed_forward1(src)
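The same `block_size == 0` convention is used for every attention sub-module in this layer: 0 means the ordinary full-attention `forward`, any positive value routes to the new `forward_block`. A compact way to read the pattern (a sketch; `run_attn` is a made-up helper, not something in the diff):

    def run_attn(module, *args, block_size=0, block_pad=16, **kwargs):
        # block_size == 0: full attention over the whole sequence.
        # block_size > 0 : block-wise attention, each block seeing block_pad
        #                  extra frames of context on either side.
        if block_size == 0:
            return module(*args, **kwargs)
        return module.forward_block(*args, block_size=block_size,
                                    block_pad=block_pad, **kwargs)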
@@ -733,11 +745,20 @@ class Zipformer2EncoderLayer(nn.Module):
            selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
            selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))

-        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+        if block_size == 0:
+            na = self.nonlin_attention(src, selected_attn_weights)
+        else:
+            na = self.nonlin_attention.forward_block(
+                src, selected_attn_weights, block_size=block_size, block_pad=block_pad)
+        na = self.balancer_na(na)

        src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)

-        self_attn = self.self_attn1(src, attn_weights)
+        if block_size == 0:
+            self_attn = self.self_attn1(src, attn_weights)
+        else:
+            self_attn = self.self_attn1.forward_block(
+                src, attn_weights, block_size=block_size, block_pad=block_pad)

        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
@@ -759,7 +780,11 @@ class Zipformer2EncoderLayer(nn.Module):
        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

-        self_attn = self.self_attn2(src, attn_weights)
+        if block_size == 0:
+            self_attn = self.self_attn2(src, attn_weights)
+        else:
+            self_attn = self.self_attn2.forward_block(
+                src, attn_weights, block_size=block_size, block_pad=block_pad)

        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
@@ -925,10 +950,11 @@ class Zipformer2Encoder(nn.Module):
        encoder_layer: nn.Module,
        num_layers: int,
        pos_dim: int,
-        block_size: int,
+        max_block_size: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
+        block_pad: int = 16,
        initial_layerdrop_rate: float = 0.5,
        final_layerdrop_rate: float = 0.05,
    ) -> None:
@@ -940,7 +966,8 @@ class Zipformer2Encoder(nn.Module):
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers
-        self.block_size = block_size
+        self.max_block_size = max_block_size
+        self.block_pad = block_pad

        assert 0 <= warmup_begin <= warmup_end
@@ -976,7 +1003,20 @@ class Zipformer2Encoder(nn.Module):
        Returns: a Tensor with the same shape as src.
        """
-        pos_emb = self.encoder_pos(src, block_size=self.block_size)
+        seq_len = src.size(0)
+        max_block_size = self.max_block_size
+        block_pad = self.block_pad
+        if seq_len > max_block_size:
+            num_blocks = math.ceil(seq_len / max_block_size)
+            block_size = math.ceil(seq_len / num_blocks)
+            pos_emb = self.encoder_pos(src, rel_pos=block_size + block_pad)
+            # if __name__ == "__main__":
+            if random.random() < 0.2:
+                logging.info(f"seq_len={seq_len}, block_size={block_size}")
+        else:
+            pos_emb = self.encoder_pos(src)
+            block_size = 0
+
        output = src

        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@@ -986,6 +1026,8 @@ class Zipformer2Encoder(nn.Module):
            output = mod(
                output,
                pos_emb,
+                block_size=block_size,
+                block_pad=block_pad,
                chunk_size=chunk_size,
                attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
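The two ceiling divisions above split the sequence as evenly as possible while keeping every block at or below max_block_size; a standalone sketch of that arithmetic with concrete lengths (illustrative numbers):

    import math

    def choose_block_size(seq_len: int, max_block_size: int = 512) -> int:
        # mirrors the two ceil-divisions in Zipformer2Encoder.forward above
        num_blocks = math.ceil(seq_len / max_block_size)
        return math.ceil(seq_len / num_blocks)

    print(choose_block_size(513))    # 257  (2 blocks)
    print(choose_block_size(1300))   # 434  (3 blocks)
    print(choose_block_size(2048))   # 512  (4 blocks)

Note that the block-wise path is only taken when seq_len exceeds max_block_size; otherwise block_size is set to 0 and the layers fall back to full attention.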
@@ -1371,7 +1413,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
        self.pe = pe.to(dtype=x.dtype)

-    def forward(self, x: Tensor, block_size: int = 0) -> Tensor:
+    def forward(self, x: Tensor, rel_pos: int = 0) -> Tensor:
        """Create positional encoding.

        Args:
@@ -1382,9 +1424,8 @@
            positional embedding, of shape (1, 2*time-1, `*`) or (1, 4*block_size-1, `*`).
        """
        self.extend_pe(x)
-        rel_pos = 2 * block_size if block_size != 0 else x.size(0)
-        # length of positive side: 2 * block_size
-        # length of negative side: 2 * block_size
+        if rel_pos == 0:
+            rel_pos = x.size(0)
        pos_emb = self.pe[
            self.pe.size(0) // 2
            - rel_pos
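With `rel_pos = block_size + block_pad`, the slice taken from the precomputed table (assuming it continues symmetrically as `... + 1 : center + rel_pos`, matching the unchanged lines cut off at the end of this hunk) has length `2 * rel_pos - 1`: it covers every relative offset a query inside a block can see, instead of the full `2 * seq_len - 1`. A shape-only sketch:

    import torch

    pe = torch.randn(2 * 4096 - 1, 192)     # stand-in for self.pe; center index = len // 2

    def slice_pos_emb(pe: torch.Tensor, rel_pos: int) -> torch.Tensor:
        center = pe.size(0) // 2
        return pe[center - rel_pos + 1 : center + rel_pos].unsqueeze(0)

    print(slice_pos_emb(pe, rel_pos=434 + 16).shape)   # (1, 899, 192) == 2*450 - 1
    print(slice_pos_emb(pe, rel_pos=1300).shape)       # (1, 2599, 192), the full-attention case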
@@ -1423,7 +1464,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
-        block_size: int,
        dropout: float = 0.0,
        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
                                                      (4000.0, 0.0))
@@ -1433,7 +1473,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        self.num_heads = num_heads
        self.query_head_dim = query_head_dim
        self.pos_head_dim = pos_head_dim
-        self.block_size = block_size
        self.dropout = dropout
        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
        self.name = None  # will be overwritten in training code; for diagnostics.
@@ -1486,11 +1525,158 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        pos_emb: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
) -> Tensor:
r"""
Args:
x: input of shape (seq_len, batch_size, embed_dim)
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
are True in this mask will be ignored as sources in the attention weighting.
attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
saying which positions are allowed to attend to which other positions.
Returns:
a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len)
interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
"""
x = self.in_proj(x)
query_head_dim = self.query_head_dim
pos_head_dim = self.pos_head_dim
num_heads = self.num_heads
seq_len, batch_size, _ = x.shape
query_dim = query_head_dim * num_heads
# self-attention
q = x[...,0:query_dim]
k = x[...,query_dim:2*query_dim]
# p is the position-encoding query
p = x[...,2*query_dim:]
assert p.shape[-1] == num_heads * pos_head_dim
q = self.copy_query(q) # for diagnostics only, does nothing.
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
p = self.copy_pos_query(p) # for diagnostics only, does nothing.
q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
# time1 refers to target, time2 refers to source.
q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
attn_scores = torch.matmul(q, k)
use_pos_scores = False
if torch.jit.is_scripting() or torch.jit.is_tracing():
# We can't put random.random() in the same line
use_pos_scores = True
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
use_pos_scores = True
if use_pos_scores:
pos_emb = self.linear_pos(pos_emb)
seq_len2 = 2 * seq_len - 1
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# [where seq_len2 represents relative position.]
pos_scores = torch.matmul(p, pos_emb)
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
else:
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
attn_scores = attn_scores + pos_scores
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif self.training and random.random() < 0.1:
# This is a harder way of limiting the attention scores to not be
# too large. It incurs a penalty if any of them has an absolute
# value greater than 50.0. this should be outside the normal range
# of the attention scores. We use this mechanism instead of, say,
# something added to the loss function involving the entropy,
# because once the entropy gets very small gradients through the
# softmax can become very small, and we'd get zero derivatives. The
# choices of 1.0e-04 as the scale on the penalty makes this
# mechanism vulnerable to the absolute scale of the loss function,
# but we view this as a failsafe to avoid "implausible" parameter
# values rather than a regularization method that should be active
# under normal circumstances.
attn_scores = penalize_abs_values_gt(attn_scores,
limit=25.0,
penalty=1.0e-04,
name=self.name)
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
if attn_mask is not None:
assert attn_mask.dtype == torch.bool
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
# all scores zero. It's important that this be large enough that exp(-1000)
# is exactly zero, for reasons related to const_attention_rate, it
# compares the final weights with zero.
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(1),
-1000,
)
# We use our own version of softmax, defined in scaling.py, which should
# save a little of the memory used in backprop by, if we are in
# automatic mixed precision mode (amp / autocast), by only storing the
# half-precision output for backprop purposes.
attn_weights = softmax(attn_scores, dim=-1)
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif random.random() < 0.001 and not self.training:
self._print_attn_entropy(attn_weights)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
return attn_weights
def forward_block(
self,
x: Tensor,
pos_emb: Tensor,
block_size: int,
block_pad: int,
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""
        Args:
            x: input of shape (seq_len, batch_size, embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 4*block_size-1, pos_dim)
+            block_size: size of block
+            block_pad: pad size at each side of block
            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
                are True in this mask will be ignored as sources in the attention weighting.
            attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
@@ -1524,14 +1710,13 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        p = self.copy_pos_query(p)  # for diagnostics only, does nothing.

        # divide into blocks by unfold function
-        block_size = self.block_size
        num_blocks = (seq_len + block_size - 1) // block_size
        pad_len = num_blocks * block_size - seq_len

        # (kernel, batch_size * num_blocks, channel)
        q_blocks = unfold(q, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
        p_blocks = unfold(p, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
-        k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
+        k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size + 2 * block_pad, stride=block_size, padding=block_pad)

        # time1 refers to target, time2 refers to source.
        time1 = q_blocks.size(0)
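The `unfold` helper itself is defined elsewhere in the file and is not shown in this diff; based on the call sites and the shape comments here, its contract appears to be: pad the time axis so it divides into `num_blocks` blocks (plus `padding` frames of zero context on each side), then cut out windows of length `kernel` with hop `stride`, giving `(kernel, batch_size * num_blocks, channels)`. A self-contained sketch with that assumed behaviour (the block/batch flattening order is a guess):

    import torch
    import torch.nn.functional as F

    def unfold_sketch(x, pad_len, num_blocks, kernel, stride, padding):
        # x: (seq_len, batch, channels) -> (kernel, batch * num_blocks, channels)
        seq_len, batch, channels = x.shape
        # pad_len zeros at the end make seq_len + pad_len == num_blocks * stride;
        # `padding` zeros on both sides supply the cross-block context for the keys.
        x = F.pad(x, (0, 0, 0, 0, padding, pad_len + padding))
        x = x.unfold(0, kernel, stride)            # (num_blocks, batch, channels, kernel)
        assert x.size(0) == num_blocks
        x = x.permute(3, 1, 0, 2)                  # (kernel, batch, num_blocks, channels)
        return x.reshape(kernel, batch * num_blocks, channels)

    seq_len, batch, channels = 1300, 2, 192
    block_size, block_pad = 434, 16                              # choose_block_size(1300, 512)
    num_blocks = (seq_len + block_size - 1) // block_size        # 3
    pad_len = num_blocks * block_size - seq_len                  # 2

    q = torch.randn(seq_len, batch, channels)
    q_blocks = unfold_sketch(q, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
    k_blocks = unfold_sketch(q, pad_len, num_blocks, kernel=block_size + 2 * block_pad,
                             stride=block_size, padding=block_pad)
    print(q_blocks.shape)   # torch.Size([434, 6, 192])  queries: one block each
    print(k_blocks.shape)   # torch.Size([466, 6, 192])  keys: block + 16 frames on each side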
@@ -1620,7 +1805,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        # (time2, new_batch_size)
        attn_offsets = unfold(
            attn_offsets, pad_len, num_blocks,
-            kernel=block_size * 3, stride=block_size, padding=block_size,
+            kernel=block_size + 2 * block_pad, stride=block_size, padding=block_pad,
        ).squeeze(-1)

        # Used for the blocks are all padding
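Continuing with those shapes: the per-block attention scores come out as (num_heads, batch_size * num_blocks, block_size, block_size + 2 * block_pad), so each query position attends to its own block plus block_pad frames of context on either side rather than to the whole sequence (a shape-only illustration, not code from the file):

    import torch

    num_heads, query_head_dim = 4, 32
    new_batch_size = 2 * 3                         # batch_size * num_blocks from the sketch above
    time1, time2 = 434, 434 + 2 * 16               # block_size, block_size + 2 * block_pad

    q = torch.randn(num_heads, new_batch_size, time1, query_head_dim)
    k = torch.randn(num_heads, new_batch_size, query_head_dim, time2)
    attn_scores = torch.matmul(q, k)               # (4, 6, 434, 466)
    attn_weights = attn_scores.softmax(dim=-1)     # rows sum to 1 over the padded block
    print(attn_weights.shape)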
@@ -1787,10 +1972,8 @@ class SelfAttention(nn.Module):
        embed_dim: int,
        num_heads: int,
        value_head_dim: int,
-        block_size: int,
    ) -> None:
        super().__init__()
-        self.block_size = block_size
        self.in_proj = nn.Linear(embed_dim,
                                 num_heads * value_head_dim,
                                 bias=True)
@@ -1808,6 +1991,44 @@ class SelfAttention(nn.Module):
        self,
        x: Tensor,
        attn_weights: Tensor,
) -> Tensor:
"""
Args:
x: input tensor, of shape (seq_len, batch_size, embed_dim)
attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
attn_weights.sum(dim=-1) == 1.
Returns:
a tensor with the same shape as x.
"""
(seq_len, batch_size, embed_dim) = x.shape
num_heads = attn_weights.shape[0]
assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, batch_size, seq_len, value_head_dim)
value_head_dim = x.shape[-1]
# todo: see whether there is benefit in overriding matmul
x = torch.matmul(attn_weights, x)
# v: (num_heads, batch_size, seq_len, value_head_dim)
x = x.permute(2, 1, 0, 3).contiguous().view(
seq_len, batch_size, num_heads * value_head_dim)
# returned value is of shape (seq_len, batch_size, embed_dim), like the input.
x = self.out_proj(x)
x = self.whiten(x)
return x
def forward_block(
self,
x: Tensor,
attn_weights: Tensor,
block_size: int,
block_pad: int,
    ) -> Tensor:
        """
        Args:
@@ -1817,6 +2038,8 @@ class SelfAttention(nn.Module):
              interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
              where num_blocks = (seq_len + block_size - 1) // block_size.
              Expect attn_weights.sum(dim=-1) == 1.
+           block_size: size of block
+           block_pad: pad size at each side of block
        Returns:
           a tensor with the same shape as x.
        """
@@ -1824,19 +2047,18 @@ class SelfAttention(nn.Module):
        num_heads = attn_weights.shape[0]

        # divide into blocks by unfold function
-        block_size = self.block_size
        num_blocks = (seq_len + block_size - 1) // block_size
        pad_len = num_blocks * block_size - seq_len

        new_batch_size = batch_size * num_blocks
        time1 = block_size  # target length
-        time2 = 3 * block_size  # source length
+        time2 = block_size + 2 * block_pad  # source length
        assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)

        # (time2, new_batch_size, channel)
-        x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
+        x_blocks = unfold(x, pad_len, num_blocks, kernel=time2, stride=block_size, padding=block_pad)
        x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, new_batch_size, time2, value_head_dim)
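After the unfold step, the value aggregation is an ordinary batched matmul over the enlarged batch, followed by stitching the blocks back together and discarding the pad_len frames added at the end; a shape-level sketch (the un-blocking code sits below this hunk and is not shown, so the exact reassembly here is an assumption):

    import torch

    num_heads, batch_size, seq_len = 4, 2, 1300
    block_size, block_pad, value_head_dim = 434, 16, 12
    num_blocks = (seq_len + block_size - 1) // block_size
    pad_len = num_blocks * block_size - seq_len
    new_batch_size = batch_size * num_blocks
    time1, time2 = block_size, block_size + 2 * block_pad

    attn_weights = torch.softmax(torch.randn(num_heads, new_batch_size, time1, time2), dim=-1)
    x_blocks = torch.randn(num_heads, new_batch_size, time2, value_head_dim)
    out = torch.matmul(attn_weights, x_blocks)     # (num_heads, new_batch_size, time1, value_head_dim)

    # Fold the blocks back onto the time axis and drop the trailing pad_len frames.
    # The batch-major ordering of new_batch_size is an assumption matching the
    # unfold sketch earlier.
    out = out.reshape(num_heads, batch_size, num_blocks, time1, value_head_dim)
    out = out.permute(2, 3, 1, 0, 4).reshape(num_blocks * time1, batch_size, num_heads * value_head_dim)
    out = out[:seq_len]
    print(out.shape)                               # torch.Size([1300, 2, 48])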
@@ -1960,12 +2182,10 @@
        self,
        channels: int,
        hidden_channels: int,
-        block_size: int,
    ) -> None:
        super().__init__()

        self.hidden_channels = hidden_channels
-        self.block_size = block_size

        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
@@ -2004,6 +2224,55 @@
        self,
        x: Tensor,
        attn_weights: Tensor,
) -> Tensor:
""".
Args:
x: a Tensor of shape (seq_len, batch_size, num_channels)
attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
Returns:
a Tensor with the same shape as x
"""
x = self.in_proj(x)
(seq_len, batch_size, _) = x.shape
hidden_channels = self.hidden_channels
s, x, y = x.chunk(3, dim=-1)
# s will go through tanh.
s = self.balancer(s)
s = self.tanh(s)
s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
x = self.whiten1(x)
x = x * s
x = self.identity1(x) # diagnostics only, it's the identity.
(seq_len, batch_size, embed_dim) = x.shape
num_heads = attn_weights.shape[0]
assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, batch_size, seq_len, head_dim)
x = torch.matmul(attn_weights, x)
# now x: (num_heads, batch_size, seq_len, head_dim)
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
y = self.identity2(y)
x = x * y
x = self.identity3(x)
x = self.out_proj(x)
x = self.whiten2(x)
return x
def forward_block(
self,
x: Tensor,
attn_weights: Tensor,
block_size: int,
block_pad: int,
    ) -> Tensor:
        """.
        Args:
@@ -2013,6 +2282,8 @@
              interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
              where num_blocks = (seq_len + block_size - 1) // block_size.
              Expect attn_weights.sum(dim=-1) == 1.
+           block_size: size of block
+           block_pad: pad size at each side of block
        Returns:
           a Tensor with the same shape as x
        """
@@ -2037,17 +2308,16 @@
        num_heads = attn_weights.shape[0]

        # divide into blocks by unfold function
-        block_size = self.block_size
        num_blocks = (seq_len + block_size - 1) // block_size
        pad_len = num_blocks * block_size - seq_len

        new_batch_size = batch_size * num_blocks
        time1 = block_size  # target length
-        time2 = 3 * block_size  # source length
+        time2 = block_size + 2 * block_pad  # source length
        assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)

        # (time2, new_batch_size, channel)
-        x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
+        x_blocks = unfold(x, pad_len, num_blocks, kernel=time2, stride=block_size, padding=block_pad)
        x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, new_batch_size, time2, head_dim)
@@ -2315,13 +2585,14 @@ def _test_zipformer_main(causal: bool = False):
    c = Zipformer2(
        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
        downsampling_factor=(1, 2),
-        block_size=4,
+        max_block_size=14,
+        block_pad=1,
        causal=causal,
        chunk_size=(4,) if causal else (-1,),
        left_context_frames=(64,)
    )
    batch_size = 2
-    seq_len = 14
+    seq_len = 29
    # Just make sure the forward pass runs.
    x = torch.randn(seq_len, batch_size, 64)
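The updated test drives the block-wise path by making seq_len larger than max_block_size. As a quick sanity check on the commit message's claim that blocks never exceed the cap (an illustrative, standalone snippet, not part of the test):

    import math

    def partition(seq_len: int, max_block_size: int = 512):
        num_blocks = math.ceil(seq_len / max_block_size)
        block_size = math.ceil(seq_len / num_blocks)
        return num_blocks, block_size

    for seq_len in (513, 1000, 1024, 2048, 5000):
        num_blocks, block_size = partition(seq_len)
        assert block_size <= 512                   # block_size <= max_block_size always holds
        assert num_blocks * block_size >= seq_len  # blocks cover the whole (padded) sequence
        print(seq_len, num_blocks, block_size)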