Use block-wise attention

yaozengwei 2023-07-20 19:38:03 +08:00
parent 4ab7d61008
commit 80a14f93d3
3 changed files with 215 additions and 87 deletions

View File

@@ -1602,6 +1602,33 @@ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
        return torch.cat((x, zeros), dim=-1)
+
+
+def unfold(
+    x: Tensor, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int
+) -> Tensor:
+    """
+    Args:
+      x: input of shape (seq_len, batch_size, channel)
+    Returns:
+      blocks: (kernel, batch_size * num_blocks, channel)
+    """
+    seq_len, batch_size, channel = x.size()
+    x = x.permute(1, 2, 0)  # (B, D, T)
+    x = nn.functional.pad(x, pad=(0, x_pad), value=0.0)
+    blocks = nn.functional.unfold(
+        x.unsqueeze(-1),
+        kernel_size=(kernel, 1),
+        padding=(padding, 0),
+        stride=(stride, 1),
+    )  # (B, C * kernel, num_blocks)
+    blocks = blocks.reshape(batch_size, channel, kernel, num_blocks)
+    blocks = blocks.permute(2, 0, 3, 1)
+    blocks = blocks.reshape(kernel, batch_size * num_blocks, channel)
+    return blocks
+
+
def _test_whiten():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
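The new unfold helper is what turns a (seq_len, batch_size, channel) tensor into per-block windows. A minimal sketch of its intended use, assuming only PyTorch and the helper above (the sizes are made up for illustration): kernel=block_size, stride=block_size, padding=0 gives non-overlapping query blocks, while kernel=3 * block_size, stride=block_size, padding=block_size gives key windows in which block b also sees blocks b-1 and b+1.

import torch

seq_len, batch_size, channel, block_size = 7, 2, 4, 3
num_blocks = (seq_len + block_size - 1) // block_size   # 3
pad_len = num_blocks * block_size - seq_len             # 2 frames of right-padding
x = torch.randn(seq_len, batch_size, channel)

# non-overlapping query blocks
q_blocks = unfold(x, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
assert q_blocks.shape == (block_size, batch_size * num_blocks, channel)

# overlapping key windows: own block plus one block of left/right context
k_blocks = unfold(x, pad_len, num_blocks, kernel=3 * block_size, stride=block_size, padding=block_size)
assert k_blocks.shape == (3 * block_size, batch_size * num_blocks, channel)

# the centre third of each key window is exactly the corresponding query block
k = k_blocks.reshape(3 * block_size, batch_size, num_blocks, channel)
q = q_blocks.reshape(block_size, batch_size, num_blocks, channel)
assert torch.equal(k[block_size:2 * block_size], q)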

View File

@@ -187,6 +187,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Positional-encoding embedding dimension",
    )

+    parser.add_argument(
+        "--block-size",
+        type=int,
+        default="32",
+        help="Block size used in block-wise attention",
+    )
+
    parser.add_argument(
        "--encoder-unmasked-dim",
        type=str,
@@ -574,6 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
        num_heads=_to_int_tuple(params.num_heads),
        feedforward_dim=_to_int_tuple(params.feedforward_dim),
        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
+        block_size=params.block_size,
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
        causal=params.causal,
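Because each Zipformer2 encoder stack runs at its own downsampled frame rate, the single --block-size value is divided by the per-stack downsampling factor before it reaches the layers (see block_size // ds in the zipformer.py changes below). A hypothetical configuration, only to illustrate the arithmetic (the factors are an example, not taken from this commit):

block_size = 32                            # value of --block-size
downsampling_factor = (1, 2, 4, 8, 4, 2)   # example per-stack factors

# effective attention block size inside each stack, in that stack's own frames
per_stack_block_size = [block_size // ds for ds in downsampling_factor]
print(per_stack_block_size)                # [32, 16, 8, 4, 8, 16]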

View File

@@ -39,6 +39,7 @@ from scaling import (
    FloatLike,
    limit_param_value,
    convert_num_channels,
+    unfold,
)
from torch import Tensor, nn
@@ -105,6 +106,7 @@ class Zipformer2(EncoderInterface):
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
+        block_size: int = 32,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
        causal: bool = False,
@@ -140,6 +142,7 @@ class Zipformer2(EncoderInterface):
        self.num_heads = num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
+        self.block_size = block_size
        self.causal = causal
        self.chunk_size = chunk_size
@@ -153,6 +156,7 @@ class Zipformer2(EncoderInterface):
        num_encoders = len(downsampling_factor)
        for i in range(num_encoders):
+            ds = downsampling_factor[i]
            encoder_layer = Zipformer2EncoderLayer(
                embed_dim=encoder_dim[i],
@@ -164,6 +168,7 @@ class Zipformer2(EncoderInterface):
                feedforward_dim=feedforward_dim[i],
                dropout=dropout,
                cnn_module_kernel=cnn_module_kernel[i],
+                block_size=block_size // ds,
                causal=causal,
            )
@@ -173,13 +178,14 @@ class Zipformer2(EncoderInterface):
                encoder_layer,
                num_encoder_layers[i],
                pos_dim=pos_dim,
+                block_size=block_size // ds,
                dropout=dropout,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
            )

-            if downsampling_factor[i] != 1:
+            if ds != 1:
                encoder = DownsampledZipformer2Encoder(
                    encoder,
                    dim=encoder_dim[i],
@@ -536,6 +542,7 @@ class Zipformer2EncoderLayer(nn.Module):
        feedforward_dim: int,
        dropout: FloatLike = 0.1,
        cnn_module_kernel: int = 31,
+        block_size: int = 32,
        causal: bool = False,
        attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
@@ -569,14 +576,14 @@ class Zipformer2EncoderLayer(nn.Module):
        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
            embed_dim, pos_dim=pos_dim, num_heads=num_heads,
            query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
-            dropout=0.0,
+            block_size=block_size, dropout=0.0,
        )

        self.self_attn1 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim)
+                                        value_head_dim, block_size=block_size)

        self.self_attn2 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim)
+                                        value_head_dim, block_size=block_size)

        self.feed_forward1 = FeedforwardModule(embed_dim,
                                               (feedforward_dim * 3) // 4,
@@ -591,7 +598,8 @@ class Zipformer2EncoderLayer(nn.Module):
                                               dropout)

        self.nonlin_attention = NonlinAttention(embed_dim,
-                                                hidden_channels=3 * embed_dim // 4)
+                                                hidden_channels=3 * embed_dim // 4,
+                                                block_size=block_size)

        self.conv_module1 = ConvolutionModule(embed_dim,
                                              cnn_module_kernel,
@@ -917,6 +925,7 @@ class Zipformer2Encoder(nn.Module):
        encoder_layer: nn.Module,
        num_layers: int,
        pos_dim: int,
+        block_size: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
@@ -931,6 +940,7 @@ class Zipformer2Encoder(nn.Module):
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers
+        self.block_size = block_size

        assert 0 <= warmup_begin <= warmup_end
@@ -966,7 +976,7 @@ class Zipformer2Encoder(nn.Module):
          Returns: a Tensor with the same shape as src.
        """
-        pos_emb = self.encoder_pos(src)
+        pos_emb = self.encoder_pos(src, block_size=self.block_size)
        output = src

        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@@ -1314,9 +1324,9 @@ class CompactRelPositionalEncoding(torch.nn.Module):
        self.length_factor = length_factor
        self.extend_pe(torch.tensor(0.0).expand(max_len))

-    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
+    def extend_pe(self, x: Tensor) -> None:
        """Reset the positional encodings."""
-        T = x.size(0) + left_context_len
+        T = x.size(0)

        if self.pe is not None:
            # self.pe contains both positive and negative parts
@@ -1361,25 +1371,25 @@ class CompactRelPositionalEncoding(torch.nn.Module):
        self.pe = pe.to(dtype=x.dtype)

-    def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+    def forward(self, x: Tensor, block_size: int = 0) -> Tensor:
        """Create positional encoding.

        Args:
            x (Tensor): Input tensor (time, batch, `*`).
-            left_context_len: (int): Length of cached left context.
+            block_size (int):

        Returns:
-            positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
+            positional embedding, of shape (1, 2*time-1, `*`) or (1, 4*block_size-1, `*`).
        """
-        self.extend_pe(x, left_context_len)
-        x_size_left = x.size(0) + left_context_len
-        # length of positive side: x.size(0) + left_context_len
-        # length of negative side: x.size(0)
+        self.extend_pe(x)
+        rel_pos = 2 * block_size if block_size != 0 else x.size(0)
+        # length of positive side: 2 * block_size
+        # length of negative side: 2 * block_size
        pos_emb = self.pe[
            self.pe.size(0) // 2
-            - x_size_left
+            - rel_pos
            + 1 : self.pe.size(0) // 2  # noqa E203
-            + x.size(0),
+            + rel_pos,
            :
        ]
        pos_emb = pos_emb.unsqueeze(0)
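With block_size != 0 the slice taken from self.pe has length 2 * rel_pos - 1 = 4 * block_size - 1, which is exactly the number of distinct relative offsets a query frame inside one block can have with respect to its 3-block key window. A small self-contained check of that arithmetic (the value of block_size is illustrative):

block_size = 32
rel_pos = 2 * block_size                         # branch taken when block_size != 0
assert 2 * rel_pos - 1 == 4 * block_size - 1     # length of the pos_emb slice

# query frame i in [0, block_size) vs. key frame j in [-block_size, 2*block_size)
offsets = {j - i for i in range(block_size)
           for j in range(-block_size, 2 * block_size)}
assert len(offsets) == 4 * block_size - 1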
@@ -1413,6 +1423,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
+        block_size: int,
        dropout: float = 0.0,
        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
                                                      (4000.0, 0.0))
@@ -1422,6 +1433,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        self.num_heads = num_heads
        self.query_head_dim = query_head_dim
        self.pos_head_dim = pos_head_dim
+        self.block_size = block_size
        self.dropout = dropout
        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
        self.name = None  # will be overwritten in training code; for diagnostics.
@@ -1478,16 +1490,19 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        r"""
        Args:
            x: input of shape (seq_len, batch_size, embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 4*block_size-1, pos_dim)
            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
               are True in this mask will be ignored as sources in the attention weighting.
            attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
               interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
               saying which positions are allowed to attend to which other positions.
        Returns:
-            a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len)
-            interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
+            a tensor of attention weights, of shape (hum_heads, batch_size * num_blocks, block_size, block_size * 3)
+            interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
+            where num_blocks = (seq_len + block_size - 1) // block_size.
        """
+        assert attn_mask is None, "Not supported yet"
+
        x = self.in_proj(x)
        query_head_dim = self.query_head_dim
        pos_head_dim = self.pos_head_dim
@@ -1508,16 +1523,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
        p = self.copy_pos_query(p)  # for diagnostics only, does nothing.

-        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
-        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
-        k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
+        # divide into blocks by unfold function
+        block_size = self.block_size
+        num_blocks = (seq_len + block_size - 1) // block_size
+        pad_len = num_blocks * block_size - seq_len
+        # (kernel, batch_size * num_blocks, channel)
+        q_blocks = unfold(q, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
+        p_blocks = unfold(p, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
+        k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)

        # time1 refers to target, time2 refers to source.
-        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
-        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
-        attn_scores = torch.matmul(q, k)
+        time1 = q_blocks.size(0)
+        time2 = k_blocks.size(0)
+        new_batch_size = batch_size * num_blocks
+        q_blocks = q_blocks.reshape(time1, new_batch_size, num_heads, query_head_dim)
+        p_blocks = p_blocks.reshape(time1, new_batch_size, num_heads, pos_head_dim)
+        k_blocks = k_blocks.reshape(time2, new_batch_size, num_heads, query_head_dim)
+        q_blocks = q_blocks.permute(2, 1, 0, 3)  # (head, new_batch, time1, query_head_dim)
+        p_blocks = p_blocks.permute(2, 1, 0, 3)  # (head, new_batch, time1, pos_head_dim)
+        k_blocks = k_blocks.permute(2, 1, 3, 0)  # (head, new_batch, d_k, time2)
+
+        # (head, new_batch, time1, time2)
+        attn_scores = torch.matmul(q_blocks, k_blocks)

        use_pos_scores = False
        if torch.jit.is_scripting() or torch.jit.is_tracing():
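A quick, purely illustrative shape walk-through of the block-wise scores (the sizes are hypothetical, not taken from the recipe):

seq_len, batch_size, num_heads, block_size = 100, 8, 4, 32

num_blocks = (seq_len + block_size - 1) // block_size   # 4
pad_len = num_blocks * block_size - seq_len             # 28 padded frames
new_batch_size = batch_size * num_blocks                # 32
time1, time2 = block_size, 3 * block_size               # 32 query frames vs 96 key frames

assert (num_blocks, pad_len, new_batch_size) == (4, 28, 32)
# attn_scores = torch.matmul(q_blocks, k_blocks) then has shape
# (num_heads, new_batch_size, time1, time2) == (4, 32, 32, 96),
# instead of the full (num_heads, batch_size, 100, 100) score matrix.
assert (num_heads, new_batch_size, time1, time2) == (4, 32, 32, 96)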
@@ -1528,32 +1558,21 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        if use_pos_scores:
            pos_emb = self.linear_pos(pos_emb)
-            seq_len2 = 2 * seq_len - 1
-            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
-            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
-            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+            pos_emb = pos_emb.reshape(1, time1 + time2 - 1, num_heads, pos_head_dim).permute(2, 0, 3, 1)
+            # pos shape now: (head, 1, pos_dim, time1+time2-1)
+            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, time1+time2-1) -> (head, batch, time1, time1+time2-1)
            # [where seq_len2 represents relative position.]
-            pos_scores = torch.matmul(p, pos_emb)
+            pos_scores = torch.matmul(p_blocks, pos_emb)

            # the following .as_strided() expression converts the last axis of pos_scores from relative
            # to absolute position.  I don't know whether I might have got the time-offsets backwards or
            # not, but let this code define which way round it is supposed to be.
-            if torch.jit.is_tracing():
-                (num_heads, batch_size, time1, n) = pos_scores.shape
-                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(seq_len)
-                rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
-                indexes = rows + cols
-                pos_scores = pos_scores.reshape(-1, n)
-                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
-            else:
-                pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
-                                                   (pos_scores.stride(0),
-                                                    pos_scores.stride(1),
-                                                    pos_scores.stride(2)-pos_scores.stride(3),
-                                                    pos_scores.stride(3)),
-                                                   storage_offset=pos_scores.stride(3) * (seq_len - 1))
+            pos_scores = pos_scores.as_strided((num_heads, new_batch_size, time1, time2),
+                                               (pos_scores.stride(0),
+                                                pos_scores.stride(1),
+                                                pos_scores.stride(2)-pos_scores.stride(3),
+                                                pos_scores.stride(3)),
+                                               storage_offset=pos_scores.stride(3) * (time1 - 1))

            attn_scores = attn_scores + pos_scores
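The as_strided expression is the usual relative-to-absolute shift; after it, entry [..., i, j] holds the score of query frame i against key frame j of its window, i.e. relative offset j - i. A standalone sketch with made-up sizes that checks this indexing convention:

import torch

num_heads, new_batch_size, time1, time2 = 2, 3, 4, 12

# Fake "pos_scores" whose last axis stores the relative offset it represents:
# index r on the last axis corresponds to offset r - (time1 - 1).
rel = torch.arange(time1 + time2 - 1, dtype=torch.float32) - (time1 - 1)
pos_scores = rel.expand(num_heads, new_batch_size, time1, time1 + time2 - 1).contiguous()

shifted = pos_scores.as_strided(
    (num_heads, new_batch_size, time1, time2),
    (pos_scores.stride(0),
     pos_scores.stride(1),
     pos_scores.stride(2) - pos_scores.stride(3),
     pos_scores.stride(3)),
    storage_offset=pos_scores.stride(3) * (time1 - 1),
)

# After the shift, entry [..., i, j] should be the relative offset j - i.
i = torch.arange(time1).unsqueeze(-1)
j = torch.arange(time2).unsqueeze(0)
assert torch.equal(shifted[0, 0], (j - i).to(torch.float32))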
@@ -1577,9 +1596,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                penalty=1.0e-04,
                name=self.name)

-        assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
+        assert attn_scores.shape == (num_heads, new_batch_size, time1, time2)

        if attn_mask is not None:
+            # TODO:
            assert attn_mask.dtype == torch.bool
            # use -1000 to avoid nan's where attn_mask and key_padding_mask make
            # all scores zero.  It's important that this be large enough that exp(-1000)
@@ -1587,12 +1607,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            # compares the final weights with zero.
            attn_scores = attn_scores.masked_fill(attn_mask, -1000)

-        if key_padding_mask is not None:
-            assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
-            attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
-                -1000,
-            )
+        assert key_padding_mask is not None
+        assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
+        attn_offsets = (~key_padding_mask).float()  # 0 at padding positions
+        # (seq_len, batch, 1)
+        attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)
+        # (time2, new_batch_size)
+        attn_offsets = unfold(
+            attn_offsets, pad_len, num_blocks,
+            kernel=block_size * 3, stride=block_size, padding=block_size,
+        ).squeeze(-1)
+        # For the blocks are all padding
+        all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0)  # (1, new_batch_size)
+        all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1)  # (1, new_batch_size, 1, 1)
+        attn_offsets = 1 - attn_offsets  # 1 at padding positions
+        # attn_offsets[attn_offsets != 0] = float("-inf")
+        attn_offsets[attn_offsets != 0] = -1000
+        # attn_offsets = attn_offsets.masked_fill((attn_offsets != 0), -1000)
+        # (1, new_batch_size, 1, time2)
+        attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0)
+        attn_scores = attn_scores + attn_offsets

        # We use our own version of softmax, defined in scaling.py, which should
        # save a little of the memory used in backprop by, if we are in
@@ -1600,6 +1639,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        # half-precision output for backprop purposes.
        attn_weights = softmax(attn_scores, dim=-1)

+        # For the blocks are all padding
+        attn_weights = attn_weights.masked_fill(all_pad_mask, 0.0)
+
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif random.random() < 0.001 and not self.training:
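An illustration of the padding handling above, with made-up sizes; it assumes the unfold helper from scaling.py is importable. Padded key positions get a -1000 additive offset before the softmax, and blocks whose whole 3-block window is padding are zeroed afterwards via all_pad_mask.

import torch

batch_size, seq_len, block_size = 2, 7, 3
num_blocks = (seq_len + block_size - 1) // block_size      # 3
pad_len = num_blocks * block_size - seq_len                # 2

lengths = torch.tensor([7, 2])                             # second utterance is mostly padding
key_padding_mask = torch.arange(seq_len).unsqueeze(0) >= lengths.unsqueeze(1)  # True at padding

attn_offsets = (~key_padding_mask).float()                 # 0 at padding positions
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)  # (seq_len, batch, 1)
attn_offsets = unfold(attn_offsets, pad_len, num_blocks,
                      kernel=block_size * 3, stride=block_size,
                      padding=block_size).squeeze(-1)      # (3 * block_size, batch * num_blocks)

# blocks whose whole 3-block key window is padding
all_pad = (attn_offsets.sum(dim=0) == 0).view(batch_size, num_blocks)
print(all_pad)   # [[False, False, False], [False, False, True]]
# The last block of the 2-frame utterance sees no real frame even with its
# left/right context, so its softmax output is zeroed by all_pad_mask.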
@@ -1678,7 +1720,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
            # [where seq_len2 represents relative position.]
            pos_scores = torch.matmul(p, pos_emb)

            if torch.jit.is_tracing():
                (num_heads, batch_size, time1, n) = pos_scores.shape
                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
@@ -1743,8 +1785,10 @@ class SelfAttention(nn.Module):
        embed_dim: int,
        num_heads: int,
        value_head_dim: int,
+        block_size: int,
    ) -> None:
        super().__init__()
+        self.block_size = block_size
        self.in_proj = nn.Linear(embed_dim,
                                 num_heads * value_head_dim,
                                 bias=True)
@@ -1766,27 +1810,45 @@ class SelfAttention(nn.Module):
        """
        Args:
          x: input tensor, of shape (seq_len, batch_size, embed_dim)
-          attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
-            with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
-            attn_weights.sum(dim=-1) == 1.
+          attn_weights: a tensor of attention weights, of shape
+            (hum_heads, batch_size * num_blocks, block_size, block_size * 3)
+            interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
+            where num_blocks = (seq_len + block_size - 1) // block_size.
+            Expect attn_weights.sum(dim=-1) == 1.
        Returns:
           a tensor with the same shape as x.
        """
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

+        # divide into blocks by unfold function
+        block_size = self.block_size
+        num_blocks = (seq_len + block_size - 1) // block_size
+        pad_len = num_blocks * block_size - seq_len
+        new_batch_size = batch_size * num_blocks
+        time1 = block_size  # target length
+        time2 = 3 * block_size  # source length
+        assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
-        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-        # now x: (num_heads, batch_size, seq_len, value_head_dim)
-        value_head_dim = x.shape[-1]
+
+        # (time2, new_batch_size, channel)
+        x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
+        x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
+        # now x: (num_heads, new_batch_size, time2, value_head_dim)
+        value_head_dim = x_blocks.shape[-1]

        # todo: see whether there is benefit in overriding matmul
-        x = torch.matmul(attn_weights, x)
-        # v: (num_heads, batch_size, seq_len, value_head_dim)
+        x = torch.matmul(attn_weights, x_blocks)
+        # v: (num_heads, new_batch_size, time1, value_head_dim)

-        x = x.permute(2, 1, 0, 3).contiguous().view(
-            seq_len, batch_size, num_heads * value_head_dim)
+        x = x.reshape(num_heads, batch_size, num_blocks, time1, value_head_dim)
+        x = x.permute(2, 3, 1, 0, 4).contiguous().view(
+            num_blocks * time1, batch_size, num_heads * value_head_dim)
+        x = x[:seq_len]  # (seq_len, batch_size, value_dim)

        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)
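A hedged sketch of the value path above, outside the module, with hypothetical sizes; it assumes the unfold helper from scaling.py is in scope. Values are unfolded into the same 3-block windows as the keys, the per-block weights are applied, and the per-block outputs are stitched back together and truncated to seq_len.

import torch

seq_len, batch_size, num_heads, value_head_dim, block_size = 7, 2, 4, 6, 3
num_blocks = (seq_len + block_size - 1) // block_size      # 3
pad_len = num_blocks * block_size - seq_len                # 2
new_batch_size = batch_size * num_blocks
time1, time2 = block_size, 3 * block_size

x = torch.randn(seq_len, batch_size, num_heads * value_head_dim)
attn_weights = torch.softmax(
    torch.randn(num_heads, new_batch_size, time1, time2), dim=-1)

x_blocks = unfold(x, pad_len, num_blocks, kernel=time2,
                  stride=block_size, padding=block_size)
x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)

out = torch.matmul(attn_weights, x_blocks)                 # (heads, new_batch, time1, head_dim)
out = out.reshape(num_heads, batch_size, num_blocks, time1, value_head_dim)
out = out.permute(2, 3, 1, 0, 4).contiguous().view(
    num_blocks * time1, batch_size, num_heads * value_head_dim)
out = out[:seq_len]                                        # back to (seq_len, batch, embed)
assert out.shape == x.shape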
@@ -1896,10 +1958,12 @@ class NonlinAttention(nn.Module):
        self,
        channels: int,
        hidden_channels: int,
+        block_size: int,
    ) -> None:
        super().__init__()

        self.hidden_channels = hidden_channels
+        self.block_size = block_size

        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
@@ -1942,7 +2006,11 @@ class NonlinAttention(nn.Module):
        """.
        Args:
           x: a Tensor of shape (seq_len, batch_size, num_channels)
-           attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+           attn_weights: a tensor of attention weights, of shape
+               (hum_heads, batch_size * num_blocks, block_size, block_size * 3)
+               interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
+               where num_blocks = (seq_len + block_size - 1) // block_size.
+               Expect attn_weights.sum(dim=-1) == 1.
        Returns:
           a Tensor with the same shape as x
        """
@@ -1965,13 +2033,31 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

-        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-        # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = torch.matmul(attn_weights, x)
-        # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+        # divide into blocks by unfold function
+        block_size = self.block_size
+        num_blocks = (seq_len + block_size - 1) // block_size
+        pad_len = num_blocks * block_size - seq_len
+        new_batch_size = batch_size * num_blocks
+        time1 = block_size  # target length
+        time2 = 3 * block_size  # source length
+        assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)

+        # (time2, new_batch_size, channel)
+        x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
+        x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
+        # now x: (num_heads, new_batch_size, time2, head_dim)
+        x = torch.matmul(attn_weights, x_blocks)
+        # now x: (num_heads, new_batch_size, time1, head_dim)
+        x = x.reshape(num_heads, batch_size, num_blocks, time1, -1)
+        x = x.permute(2, 3, 1, 0, 4).contiguous().view(
+            num_blocks * time1, batch_size, embed_dim)
+        x = x[:seq_len]  # (seq_len, batch_size, embed_dim)

        y = self.identity2(y)
        x = x * y
@@ -2220,30 +2306,37 @@ class ScalarMultiply(nn.Module):
def _test_zipformer_main(causal: bool = False):
-    batch_size = 5
-    seq_len = 20
    # Just make sure the forward pass runs.
+    from icefall.utils import make_pad_mask

    c = Zipformer2(
        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
+        downsampling_factor=(1, 2),
+        block_size=4,
        causal=causal,
        chunk_size=(4,) if causal else (-1,),
        left_context_frames=(64,)
    )
-    batch_size = 5
-    seq_len = 20
+    batch_size = 2
+    seq_len = 14
    # Just make sure the forward pass runs.
-    f = c(
-        torch.randn(seq_len, batch_size, 64),
-        torch.full((batch_size,), seq_len, dtype=torch.int64),
-    )
+    x = torch.randn(seq_len, batch_size, 64)
+    lengths = torch.full((batch_size,), seq_len, dtype=torch.int64)
+    lengths[-1] = 1
+    src_key_padding_mask = make_pad_mask(lengths)
+    f = c(x, lengths, src_key_padding_mask)
    f[0].sum().backward()
    c.eval()
-    f = c(
-        torch.randn(seq_len, batch_size, 64),
-        torch.full((batch_size,), seq_len, dtype=torch.int64),
-    )
+
+    x = torch.randn(seq_len, batch_size, 64)
+    lengths = torch.full((batch_size,), seq_len, dtype=torch.int64)
+    lengths[-1] = seq_len - 2
+    src_key_padding_mask = make_pad_mask(lengths)
+    f = c(x, lengths, src_key_padding_mask)
    f  # to remove flake8 warnings
+    print(f[0].sum())


if __name__ == "__main__":
@@ -2251,4 +2344,4 @@ if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_zipformer_main(False)
-    _test_zipformer_main(True)
+    # _test_zipformer_main(True)