mirror of https://github.com/k2-fsa/icefall.git
commit 215541c7c5 (parent ee485c02fc)

Do block-wise attention when seq_len is larger than 512, with block_size <= 512
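In brief: when the input is longer than max_block_size frames (default 512), the encoder now splits it into near-equal blocks of at most max_block_size frames and lets each block attend only to itself plus block_pad frames of context on either side. A minimal sketch of the block-size choice, mirroring the math.ceil logic added below (the helper name is ours, not the commit's):

import math

def choose_block_size(seq_len: int, max_block_size: int = 512) -> int:
    # 0 means "no blocking": the layer falls back to full self-attention.
    if seq_len <= max_block_size:
        return 0
    num_blocks = math.ceil(seq_len / max_block_size)
    # near-equal blocks, each <= max_block_size
    return math.ceil(seq_len / num_blocks)

print(choose_block_size(400))    # 0: short sequences keep full attention
print(choose_block_size(1500))   # 500: three blocks cover 1500 frames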
@@ -188,10 +188,10 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--block-size",
+        "--max-block-size",
         type=str,
-        default="32",
-        help="Block size used in block-wise attention; a single int or comma-separated list",
+        default="512",
+        help="Max block size used in block-wise attention; a single int or comma-separated list",
     )
 
     parser.add_argument(
@@ -581,7 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         num_heads=_to_int_tuple(params.num_heads),
         feedforward_dim=_to_int_tuple(params.feedforward_dim),
         cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
-        block_size=_to_int_tuple(params.block_size),
+        max_block_size=_to_int_tuple(params.max_block_size),
         dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
         warmup_batches=4000.0,
         causal=params.causal,
@@ -106,7 +106,8 @@ class Zipformer2(EncoderInterface):
         feedforward_dim: Union[int, Tuple[int]] = 1536,
         cnn_module_kernel: Union[int, Tuple[int]] = 31,
         pos_dim: int = 192,
-        block_size: Union[int, Tuple[int]] = 32,
+        max_block_size: Union[int, Tuple[int]] = 512,
+        block_pad: int = 16,
         dropout: FloatLike = None,  # see code below for default
         warmup_batches: float = 4000.0,
         causal: bool = False,
@@ -142,7 +143,7 @@ class Zipformer2(EncoderInterface):
         self.num_heads = num_heads = _to_tuple(num_heads)
         feedforward_dim = _to_tuple(feedforward_dim)
         self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
-        self.block_size = block_size = _to_tuple(block_size)
+        self.max_block_size = max_block_size = _to_tuple(max_block_size)
 
         self.causal = causal
         self.chunk_size = chunk_size
@@ -168,7 +169,6 @@ class Zipformer2(EncoderInterface):
                 feedforward_dim=feedforward_dim[i],
                 dropout=dropout,
                 cnn_module_kernel=cnn_module_kernel[i],
-                block_size=block_size[i],
                 causal=causal,
             )
 
@@ -178,7 +178,8 @@ class Zipformer2(EncoderInterface):
                 encoder_layer,
                 num_encoder_layers[i],
                 pos_dim=pos_dim,
-                block_size=block_size[i],
+                max_block_size=max_block_size[i],
+                block_pad=block_pad,
                 dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
@@ -542,7 +543,6 @@ class Zipformer2EncoderLayer(nn.Module):
         feedforward_dim: int,
         dropout: FloatLike = 0.1,
         cnn_module_kernel: int = 31,
-        block_size: int = 32,
         causal: bool = False,
         attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
         conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
@@ -576,14 +576,14 @@ class Zipformer2EncoderLayer(nn.Module):
         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
             embed_dim, pos_dim=pos_dim, num_heads=num_heads,
             query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
-            block_size=block_size, dropout=0.0,
+            dropout=0.0,
         )
 
         self.self_attn1 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim, block_size=block_size)
+                                        value_head_dim)
 
         self.self_attn2 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim, block_size=block_size)
+                                        value_head_dim)
 
         self.feed_forward1 = FeedforwardModule(embed_dim,
                                                (feedforward_dim * 3) // 4,
@@ -598,8 +598,7 @@ class Zipformer2EncoderLayer(nn.Module):
                                                dropout)
 
         self.nonlin_attention = NonlinAttention(embed_dim,
-                                                hidden_channels=3 * embed_dim // 4,
-                                                block_size=block_size)
+                                                hidden_channels=3 * embed_dim // 4)
 
         self.conv_module1 = ConvolutionModule(embed_dim,
                                               cnn_module_kernel,
@@ -680,6 +679,8 @@ class Zipformer2EncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
+        block_size: int = 0,
+        block_pad: int = 16,
         chunk_size: int = -1,
         attn_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
@@ -689,6 +690,8 @@ class Zipformer2EncoderLayer(nn.Module):
         Args:
             src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
             pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+            block_size: size of each block; 0 means no block-wise attention.
+            block_pad: pad size at each side of a block.
             chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
             feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
@@ -710,12 +713,21 @@ class Zipformer2EncoderLayer(nn.Module):
         attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
 
         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights = self.self_attn_weights(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=attn_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
+        if block_size == 0:
+            attn_weights = self.self_attn_weights(
+                src,
+                pos_emb=pos_emb,
+                attn_mask=attn_mask,
+                key_padding_mask=src_key_padding_mask,
+            )
+        else:
+            attn_weights = self.self_attn_weights.forward_block(
+                src,
+                pos_emb=pos_emb,
+                block_size=block_size,
+                block_pad=block_pad,
+                key_padding_mask=src_key_padding_mask,
+            )
 
         src = src + self.feed_forward1(src)
 
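For orientation, these are the two attention-weight layouts the layer now handles; the shapes follow the forward/forward_block docstrings in this commit, while the concrete sizes below are illustrative:

import torch

num_heads, batch_size, seq_len = 4, 2, 100
block_size, block_pad = 25, 16
num_blocks = (seq_len + block_size - 1) // block_size    # 4

full_weights = torch.zeros(num_heads, batch_size, seq_len, seq_len)
block_weights = torch.zeros(num_heads, batch_size * num_blocks,
                            block_size, block_size + 2 * block_pad)
print(full_weights.numel(), block_weights.numel())       # 80000 45600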
@@ -733,11 +745,20 @@ class Zipformer2EncoderLayer(nn.Module):
             selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
             selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
 
-        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+        if block_size == 0:
+            na = self.nonlin_attention(src, selected_attn_weights)
+        else:
+            na = self.nonlin_attention.forward_block(
+                src, selected_attn_weights, block_size=block_size, block_pad=block_pad)
+        na = self.balancer_na(na)
 
         src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
 
-        self_attn = self.self_attn1(src, attn_weights)
+        if block_size == 0:
+            self_attn = self.self_attn1(src, attn_weights)
+        else:
+            self_attn = self.self_attn1.forward_block(
+                src, attn_weights, block_size=block_size, block_pad=block_pad)
 
         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
@@ -759,7 +780,11 @@ class Zipformer2EncoderLayer(nn.Module):
         # bypass in the middle of the layer.
         src = self.bypass_mid(src_orig, src)
 
-        self_attn = self.self_attn2(src, attn_weights)
+        if block_size == 0:
+            self_attn = self.self_attn2(src, attn_weights)
+        else:
+            self_attn = self.self_attn2.forward_block(
+                src, attn_weights, block_size=block_size, block_pad=block_pad)
 
         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
@@ -925,10 +950,11 @@ class Zipformer2Encoder(nn.Module):
         encoder_layer: nn.Module,
         num_layers: int,
         pos_dim: int,
-        block_size: int,
+        max_block_size: int,
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
+        block_pad: int = 16,
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
     ) -> None:
@@ -940,7 +966,8 @@ class Zipformer2Encoder(nn.Module):
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers
-        self.block_size = block_size
+        self.max_block_size = max_block_size
+        self.block_pad = block_pad
 
         assert 0 <= warmup_begin <= warmup_end
 
@@ -976,7 +1003,20 @@ class Zipformer2Encoder(nn.Module):
 
         Returns: a Tensor with the same shape as src.
         """
-        pos_emb = self.encoder_pos(src, block_size=self.block_size)
+        seq_len = src.size(0)
+        max_block_size = self.max_block_size
+        block_pad = self.block_pad
+        if seq_len > max_block_size:
+            num_blocks = math.ceil(seq_len / max_block_size)
+            block_size = math.ceil(seq_len / num_blocks)
+            pos_emb = self.encoder_pos(src, rel_pos=block_size + block_pad)
+            # if __name__ == "__main__":
+            if random.random() < 0.2:
+                logging.info(f"seq_len={seq_len}, block_size={block_size}")
+        else:
+            pos_emb = self.encoder_pos(src)
+            block_size = 0
 
         output = src
 
         if not torch.jit.is_scripting() and not torch.jit.is_tracing():
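Rough arithmetic on what blocking saves per layer and head (illustrative numbers, not from the commit): attention-score storage drops from seq_len**2 entries to num_blocks * block_size * (block_size + 2 * block_pad):

import math

seq_len, max_block_size, block_pad = 2048, 512, 16
full = seq_len ** 2                                   # 4,194,304 score entries
num_blocks = math.ceil(seq_len / max_block_size)      # 4
block_size = math.ceil(seq_len / num_blocks)          # 512
blockwise = num_blocks * block_size * (block_size + 2 * block_pad)
print(full, blockwise)                                # 4194304 1114112, ~3.8x smaller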
@@ -986,6 +1026,8 @@ class Zipformer2Encoder(nn.Module):
             output = mod(
                 output,
                 pos_emb,
+                block_size=block_size,
+                block_pad=block_pad,
                 chunk_size=chunk_size,
                 attn_mask=attn_mask,
                 src_key_padding_mask=src_key_padding_mask,
@@ -1371,7 +1413,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
 
         self.pe = pe.to(dtype=x.dtype)
 
-    def forward(self, x: Tensor, block_size: int = 0) -> Tensor:
+    def forward(self, x: Tensor, rel_pos: int = 0) -> Tensor:
         """Create positional encoding.
 
         Args:
@@ -1382,9 +1424,8 @@ class CompactRelPositionalEncoding(torch.nn.Module):
             positional embedding, of shape (1, 2*time-1, `*`) or (1, 4*block_size-1, `*`).
         """
         self.extend_pe(x)
-        rel_pos = 2 * block_size if block_size != 0 else x.size(0)
-        # length of positive side: 2 * block_size
-        # length of negative side: 2 * block_size
+        if rel_pos == 0:
+            rel_pos = x.size(0)
         pos_emb = self.pe[
             self.pe.size(0) // 2
             - rel_pos
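The rel_pos argument lets callers request a narrower span of relative-position embeddings than the full 2*seq_len-1; the encoder passes rel_pos=block_size + block_pad in the block-wise case. A sketch of the central slice this produces (the table size here is an assumption, and the exact slice endpoints are cut off by the hunk; this version yields the expected 2*rel_pos-1 length):

import torch

pe = torch.randn(2 * 4096 - 1, 192)         # assumed precomputed table
rel_pos = 500 + 16                          # block_size + block_pad
center = pe.size(0) // 2
pos_emb = pe[center - rel_pos + 1 : center + rel_pos]
print(pos_emb.shape[0], 2 * rel_pos - 1)    # 1031 1031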
@@ -1423,7 +1464,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         num_heads: int,
         query_head_dim: int,
         pos_head_dim: int,
-        block_size: int,
         dropout: float = 0.0,
         pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
                                                       (4000.0, 0.0))
@@ -1433,7 +1473,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.num_heads = num_heads
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
-        self.block_size = block_size
         self.dropout = dropout
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         self.name = None  # will be overwritten in training code; for diagnostics.
@@ -1486,11 +1525,158 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""
         Args:
             x: input of shape (seq_len, batch_size, embed_dim)
             pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
             key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
                are True in this mask will be ignored as sources in the attention weighting.
             attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
                interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
                saying which positions are allowed to attend to which other positions.
         Returns:
             a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
             interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
         """
         x = self.in_proj(x)
         query_head_dim = self.query_head_dim
         pos_head_dim = self.pos_head_dim
         num_heads = self.num_heads
 
         seq_len, batch_size, _ = x.shape
 
         query_dim = query_head_dim * num_heads
 
         # self-attention
         q = x[..., 0:query_dim]
         k = x[..., query_dim:2 * query_dim]
         # p is the position-encoding query
         p = x[..., 2 * query_dim:]
         assert p.shape[-1] == num_heads * pos_head_dim
 
         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
         p = self.copy_pos_query(p)  # for diagnostics only, does nothing.
 
         q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
         p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
         k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
 
         # time1 refers to target, time2 refers to source.
         q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
         p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
         k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
 
         attn_scores = torch.matmul(q, k)
 
         use_pos_scores = False
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             # We can't put random.random() in the same line
             use_pos_scores = True
         elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
             use_pos_scores = True
 
         if use_pos_scores:
             pos_emb = self.linear_pos(pos_emb)
             seq_len2 = 2 * seq_len - 1
             pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
             # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
 
             # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
             # [where seq_len2 represents relative position.]
             pos_scores = torch.matmul(p, pos_emb)
             # the following .as_strided() expression converts the last axis of pos_scores from relative
             # to absolute position.  I don't know whether I might have got the time-offsets backwards or
             # not, but let this code define which way round it is supposed to be.
             if torch.jit.is_tracing():
                 (num_heads, batch_size, time1, n) = pos_scores.shape
                 rows = torch.arange(start=time1 - 1, end=-1, step=-1)
                 cols = torch.arange(seq_len)
                 rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
                 indexes = rows + cols
                 pos_scores = pos_scores.reshape(-1, n)
                 pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
                 pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
             else:
                 pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
                                                    (pos_scores.stride(0),
                                                     pos_scores.stride(1),
                                                     pos_scores.stride(2) - pos_scores.stride(3),
                                                     pos_scores.stride(3)),
                                                    storage_offset=pos_scores.stride(3) * (seq_len - 1))
 
             attn_scores = attn_scores + pos_scores
 
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
             # too large.  It incurs a penalty if any of them has an absolute
             # value greater than 50.0.  this should be outside the normal range
             # of the attention scores.  We use this mechanism instead of, say,
             # something added to the loss function involving the entropy,
             # because once the entropy gets very small gradients through the
             # softmax can become very small, and we'd get zero derivatives.  The
             # choices of 1.0e-04 as the scale on the penalty makes this
             # mechanism vulnerable to the absolute scale of the loss function,
             # but we view this as a failsafe to avoid "implausible" parameter
             # values rather than a regularization method that should be active
             # under normal circumstances.
             attn_scores = penalize_abs_values_gt(attn_scores,
                                                  limit=25.0,
                                                  penalty=1.0e-04,
                                                  name=self.name)
 
         assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
 
         if attn_mask is not None:
             assert attn_mask.dtype == torch.bool
             # use -1000 to avoid nan's where attn_mask and key_padding_mask make
             # all scores zero.  It's important that this be large enough that exp(-1000)
             # is exactly zero, for reasons related to const_attention_rate, it
             # compares the final weights with zero.
             attn_scores = attn_scores.masked_fill(attn_mask, -1000)
 
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask.unsqueeze(1),
                 -1000,
             )
 
         # We use our own version of softmax, defined in scaling.py, which should
         # save a little of the memory used in backprop by, if we are in
         # automatic mixed precision mode (amp / autocast), by only storing the
         # half-precision output for backprop purposes.
         attn_weights = softmax(attn_scores, dim=-1)
 
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif random.random() < 0.001 and not self.training:
             self._print_attn_entropy(attn_weights)
 
         attn_weights = nn.functional.dropout(
             attn_weights, p=self.dropout, training=self.training
         )
 
         return attn_weights
 
+    def forward_block(
+        self,
+        x: Tensor,
+        pos_emb: Tensor,
+        block_size: int,
+        block_pad: int,
+        key_padding_mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""
+        Args:
+            x: input of shape (seq_len, batch_size, embed_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 2*(block_size+block_pad)-1, pos_dim)
+            block_size: size of each block
+            block_pad: pad size at each side of a block
+            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
+               are True in this mask will be ignored as sources in the attention weighting.
+            attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
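The trickiest step in the weights computation above is converting pos_scores from relative to absolute indexing. A toy version of the gather branch (our own construction, mirroring the rows + cols index built in the tracing path; the as_strided branch computes the same thing without materialising the index):

import torch

seq_len = 4
# rel_scores: (seq_len, 2*seq_len - 1); column r holds relative offset r - (seq_len - 1)
rel_scores = torch.arange(seq_len * (2 * seq_len - 1), dtype=torch.float32)
rel_scores = rel_scores.reshape(seq_len, 2 * seq_len - 1)

t = torch.arange(seq_len).unsqueeze(1)   # target position (time1)
s = torch.arange(seq_len).unsqueeze(0)   # source position (time2)
# the score for pair (t, s) sits at relative column s - t + seq_len - 1
abs_scores = rel_scores.gather(1, s - t + seq_len - 1)
print(abs_scores.shape)                  # torch.Size([4, 4])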
@@ -1524,14 +1710,13 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         p = self.copy_pos_query(p)  # for diagnostics only, does nothing.
 
         # divide into blocks by unfold function
-        block_size = self.block_size
         num_blocks = (seq_len + block_size - 1) // block_size
         pad_len = num_blocks * block_size - seq_len
 
         # (kernel, batch_size * num_blocks, channel)
         q_blocks = unfold(q, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
         p_blocks = unfold(p, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
-        k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
+        k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size + 2 * block_pad, stride=block_size, padding=block_pad)
 
         # time1 refers to target, time2 refers to source.
         time1 = q_blocks.size(0)
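unfold is a helper defined in this file, outside the hunks shown. Judging by the call sites and the `(kernel, batch_size * num_blocks, channel)` comment, it pads the time axis and cuts overlapping windows at a fixed hop; a sketch of equivalent behaviour under those assumptions (the ordering of blocks inside the flattened batch axis is our guess):

import torch
import torch.nn.functional as F

def unfold_sketch(x: torch.Tensor, pad_len: int, num_blocks: int,
                  kernel: int, stride: int, padding: int) -> torch.Tensor:
    # x: (seq_len, batch, channel); pad the front by `padding` frames,
    # the back by pad_len + padding frames.
    _, batch, channel = x.shape
    x = F.pad(x, (0, 0, 0, 0, padding, pad_len + padding))
    starts = torch.arange(num_blocks) * stride
    idx = starts.unsqueeze(1) + torch.arange(kernel)      # (num_blocks, kernel)
    blocks = x[idx]                                       # (num_blocks, kernel, batch, channel)
    return blocks.permute(1, 2, 0, 3).reshape(kernel, batch * num_blocks, channel)

k = torch.randn(100, 2, 8)   # seq_len=100, block_size=25 -> num_blocks=4, pad_len=0
out = unfold_sketch(k, pad_len=0, num_blocks=4, kernel=25 + 2 * 16, stride=25, padding=16)
print(out.shape)             # torch.Size([57, 8, 8]): each key window covers 16+25+16 frames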
@@ -1620,7 +1805,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # (time2, new_batch_size)
         attn_offsets = unfold(
             attn_offsets, pad_len, num_blocks,
-            kernel=block_size * 3, stride=block_size, padding=block_size,
+            kernel=block_size + 2 * block_pad, stride=block_size, padding=block_pad,
         ).squeeze(-1)
 
         # Used for blocks that are entirely padding
@@ -1787,10 +1972,8 @@ class SelfAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         value_head_dim: int,
-        block_size: int,
     ) -> None:
         super().__init__()
-        self.block_size = block_size
         self.in_proj = nn.Linear(embed_dim,
                                  num_heads * value_head_dim,
                                  bias=True)
@@ -1808,6 +1991,44 @@ class SelfAttention(nn.Module):
         self,
         x: Tensor,
         attn_weights: Tensor,
     ) -> Tensor:
         """
         Args:
             x: input tensor, of shape (seq_len, batch_size, embed_dim)
             attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
               with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
               attn_weights.sum(dim=-1) == 1.
         Returns:
             a tensor with the same shape as x.
         """
         (seq_len, batch_size, embed_dim) = x.shape
         num_heads = attn_weights.shape[0]
         assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
 
         x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
         x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
         # now x: (num_heads, batch_size, seq_len, value_head_dim)
         value_head_dim = x.shape[-1]
 
         # todo: see whether there is benefit in overriding matmul
         x = torch.matmul(attn_weights, x)
         # v: (num_heads, batch_size, seq_len, value_head_dim)
 
         x = x.permute(2, 1, 0, 3).contiguous().view(
             seq_len, batch_size, num_heads * value_head_dim)
 
         # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
         x = self.out_proj(x)
         x = self.whiten(x)
 
         return x
 
+    def forward_block(
+        self,
+        x: Tensor,
+        attn_weights: Tensor,
+        block_size: int,
+        block_pad: int,
+    ) -> Tensor:
+        """
+        Args:
@@ -1817,6 +2038,8 @@ class SelfAttention(nn.Module):
               interpreted as (num_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
               where num_blocks = (seq_len + block_size - 1) // block_size.
               Expect attn_weights.sum(dim=-1) == 1.
+            block_size: size of each block
+            block_pad: pad size at each side of a block
         Returns:
             a tensor with the same shape as x.
         """
@@ -1824,19 +2047,18 @@ class SelfAttention(nn.Module):
         num_heads = attn_weights.shape[0]
 
         # divide into blocks by unfold function
-        block_size = self.block_size
         num_blocks = (seq_len + block_size - 1) // block_size
         pad_len = num_blocks * block_size - seq_len
         new_batch_size = batch_size * num_blocks
         time1 = block_size  # target length
-        time2 = 3 * block_size  # source length
+        time2 = block_size + 2 * block_pad  # source length
 
         assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)
 
         x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
 
         # (time2, new_batch_size, channel)
-        x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
+        x_blocks = unfold(x, pad_len, num_blocks, kernel=time2, stride=block_size, padding=block_pad)
 
         x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
         # now x: (num_heads, new_batch_size, time2, value_head_dim)
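After forward_block multiplies the block-wise weights by the unfolded values, the per-block outputs have to be stitched back into a (seq_len, batch, channel) tensor; that reassembly falls outside the displayed hunks. A sketch of the aggregation plus one plausible re-assembly (the permute/reshape order is an assumption tied to the unfold sketch above, not the commit's code):

import torch

num_heads, batch, num_blocks = 4, 2, 4
block_size, block_pad, head_dim = 25, 16, 12
time1, time2 = block_size, block_size + 2 * block_pad
seq_len = 100                     # = num_blocks * block_size, so pad_len = 0

attn = torch.softmax(torch.randn(num_heads, batch * num_blocks, time1, time2), dim=-1)
v = torch.randn(num_heads, batch * num_blocks, time2, head_dim)

out = torch.matmul(attn, v)       # (num_heads, batch * num_blocks, time1, head_dim)
out = out.permute(2, 1, 0, 3).reshape(time1, batch, num_blocks, num_heads * head_dim)
out = out.permute(2, 0, 1, 3).reshape(num_blocks * time1, batch, num_heads * head_dim)
out = out[:seq_len]               # drop the pad_len trailing frames (none here)
print(out.shape)                  # torch.Size([100, 2, 48])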
@@ -1960,12 +2182,10 @@ class NonlinAttention(nn.Module):
         self,
         channels: int,
         hidden_channels: int,
-        block_size: int,
     ) -> None:
         super().__init__()
 
         self.hidden_channels = hidden_channels
-        self.block_size = block_size
 
         self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
 
@@ -2004,6 +2224,55 @@ class NonlinAttention(nn.Module):
         self,
         x: Tensor,
         attn_weights: Tensor,
     ) -> Tensor:
         """.
         Args:
             x: a Tensor of shape (seq_len, batch_size, num_channels)
             attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         Returns:
             a Tensor with the same shape as x
         """
         x = self.in_proj(x)
 
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
         s, x, y = x.chunk(3, dim=-1)
 
         # s will go through tanh.
 
         s = self.balancer(s)
         s = self.tanh(s)
 
         s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
         x = self.whiten1(x)
         x = x * s
         x = self.identity1(x)  # diagnostics only, it's the identity.
 
         (seq_len, batch_size, embed_dim) = x.shape
         num_heads = attn_weights.shape[0]
         assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
 
         x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = torch.matmul(attn_weights, x)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
 
         y = self.identity2(y)
         x = x * y
         x = self.identity3(x)
 
         x = self.out_proj(x)
         x = self.whiten2(x)
         return x
 
+    def forward_block(
+        self,
+        x: Tensor,
+        attn_weights: Tensor,
+        block_size: int,
+        block_pad: int,
+    ) -> Tensor:
+        """.
+        Args:
@@ -2013,6 +2282,8 @@ class NonlinAttention(nn.Module):
               interpreted as (num_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
               where num_blocks = (seq_len + block_size - 1) // block_size.
               Expect attn_weights.sum(dim=-1) == 1.
+            block_size: size of each block
+            block_pad: pad size at each side of a block
         Returns:
             a Tensor with the same shape as x
         """
@@ -2037,17 +2308,16 @@ class NonlinAttention(nn.Module):
         num_heads = attn_weights.shape[0]
 
         # divide into blocks by unfold function
-        block_size = self.block_size
         num_blocks = (seq_len + block_size - 1) // block_size
         pad_len = num_blocks * block_size - seq_len
         new_batch_size = batch_size * num_blocks
         time1 = block_size  # target length
-        time2 = 3 * block_size  # source length
+        time2 = block_size + 2 * block_pad  # source length
 
         assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)
 
         # (time2, new_batch_size, channel)
-        x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
+        x_blocks = unfold(x, pad_len, num_blocks, kernel=time2, stride=block_size, padding=block_pad)
 
         x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
         # now x: (num_heads, new_batch_size, time2, head_dim)
@@ -2315,13 +2585,14 @@ def _test_zipformer_main(causal: bool = False):
     c = Zipformer2(
         encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         downsampling_factor=(1, 2),
-        block_size=4,
+        max_block_size=14,
+        block_pad=1,
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
         left_context_frames=(64,)
     )
     batch_size = 2
-    seq_len = 14
+    seq_len = 29
 
     # Just make sure the forward pass runs.
     x = torch.randn(seq_len, batch_size, 64)