Use block-wise attention

This commit is contained in:
yaozengwei 2023-07-20 19:38:03 +08:00
parent 4ab7d61008
commit 80a14f93d3
3 changed files with 215 additions and 87 deletions

View File

@ -1602,6 +1602,33 @@ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
return torch.cat((x, zeros), dim=-1)
def unfold(
x: Tensor, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int
) -> Tensor:
"""
Args:
x: input of shape (seq_len, batch_size, channel)
Returns:
blocks: (kernel, batch_size * num_blocks, channel)
"""
seq_len, batch_size, channel = x.size()
x = x.permute(1, 2, 0) # (B, D, T)
x = nn.functional.pad(x, pad=(0, x_pad), value=0.0)
blocks = nn.functional.unfold(
x.unsqueeze(-1),
kernel_size=(kernel, 1),
padding=(padding, 0),
stride=(stride, 1),
) # (B, C * kernel, num_blocks)
blocks = blocks.reshape(batch_size, channel, kernel, num_blocks)
blocks = blocks.permute(2, 0, 3, 1)
blocks = blocks.reshape(kernel, batch_size * num_blocks, channel)
return blocks
def _test_whiten():
for proportion in [0.1, 0.5, 10.0]:
logging.info(f"_test_whiten(): proportion = {proportion}")

View File

@ -187,6 +187,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Positional-encoding embedding dimension",
)
parser.add_argument(
"--block-size",
type=int,
default="32",
help="Block size used in block-wise attention",
)
parser.add_argument(
"--encoder-unmasked-dim",
type=str,
@ -574,6 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
num_heads=_to_int_tuple(params.num_heads),
feedforward_dim=_to_int_tuple(params.feedforward_dim),
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
block_size=params.block_size,
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
causal=params.causal,

View File

@ -39,6 +39,7 @@ from scaling import (
FloatLike,
limit_param_value,
convert_num_channels,
unfold,
)
from torch import Tensor, nn
@ -105,6 +106,7 @@ class Zipformer2(EncoderInterface):
feedforward_dim: Union[int, Tuple[int]] = 1536,
cnn_module_kernel: Union[int, Tuple[int]] = 31,
pos_dim: int = 192,
block_size: int = 32,
dropout: FloatLike = None, # see code below for default
warmup_batches: float = 4000.0,
causal: bool = False,
@ -140,6 +142,7 @@ class Zipformer2(EncoderInterface):
self.num_heads = num_heads = _to_tuple(num_heads)
feedforward_dim = _to_tuple(feedforward_dim)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
self.block_size = block_size
self.causal = causal
self.chunk_size = chunk_size
@ -153,6 +156,7 @@ class Zipformer2(EncoderInterface):
num_encoders = len(downsampling_factor)
for i in range(num_encoders):
ds = downsampling_factor[i]
encoder_layer = Zipformer2EncoderLayer(
embed_dim=encoder_dim[i],
@ -164,6 +168,7 @@ class Zipformer2(EncoderInterface):
feedforward_dim=feedforward_dim[i],
dropout=dropout,
cnn_module_kernel=cnn_module_kernel[i],
block_size=block_size // ds,
causal=causal,
)
@ -173,13 +178,14 @@ class Zipformer2(EncoderInterface):
encoder_layer,
num_encoder_layers[i],
pos_dim=pos_dim,
block_size=block_size // ds,
dropout=dropout,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
)
if downsampling_factor[i] != 1:
if ds != 1:
encoder = DownsampledZipformer2Encoder(
encoder,
dim=encoder_dim[i],
@ -536,6 +542,7 @@ class Zipformer2EncoderLayer(nn.Module):
feedforward_dim: int,
dropout: FloatLike = 0.1,
cnn_module_kernel: int = 31,
block_size: int = 32,
causal: bool = False,
attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
@ -569,14 +576,14 @@ class Zipformer2EncoderLayer(nn.Module):
self.self_attn_weights = RelPositionMultiheadAttentionWeights(
embed_dim, pos_dim=pos_dim, num_heads=num_heads,
query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
dropout=0.0,
block_size=block_size, dropout=0.0,
)
self.self_attn1 = SelfAttention(embed_dim, num_heads,
value_head_dim)
value_head_dim, block_size=block_size)
self.self_attn2 = SelfAttention(embed_dim, num_heads,
value_head_dim)
value_head_dim, block_size=block_size)
self.feed_forward1 = FeedforwardModule(embed_dim,
(feedforward_dim * 3) // 4,
@ -591,7 +598,8 @@ class Zipformer2EncoderLayer(nn.Module):
dropout)
self.nonlin_attention = NonlinAttention(embed_dim,
hidden_channels=3 * embed_dim // 4)
hidden_channels=3 * embed_dim // 4,
block_size=block_size)
self.conv_module1 = ConvolutionModule(embed_dim,
cnn_module_kernel,
@ -917,6 +925,7 @@ class Zipformer2Encoder(nn.Module):
encoder_layer: nn.Module,
num_layers: int,
pos_dim: int,
block_size: int,
dropout: float,
warmup_begin: float,
warmup_end: float,
@ -931,6 +940,7 @@ class Zipformer2Encoder(nn.Module):
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
self.block_size = block_size
assert 0 <= warmup_begin <= warmup_end
@ -966,7 +976,7 @@ class Zipformer2Encoder(nn.Module):
Returns: a Tensor with the same shape as src.
"""
pos_emb = self.encoder_pos(src)
pos_emb = self.encoder_pos(src, block_size=self.block_size)
output = src
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
@ -1314,9 +1324,9 @@ class CompactRelPositionalEncoding(torch.nn.Module):
self.length_factor = length_factor
self.extend_pe(torch.tensor(0.0).expand(max_len))
def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
def extend_pe(self, x: Tensor) -> None:
"""Reset the positional encodings."""
T = x.size(0) + left_context_len
T = x.size(0)
if self.pe is not None:
# self.pe contains both positive and negative parts
@ -1361,25 +1371,25 @@ class CompactRelPositionalEncoding(torch.nn.Module):
self.pe = pe.to(dtype=x.dtype)
def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
def forward(self, x: Tensor, block_size: int = 0) -> Tensor:
"""Create positional encoding.
Args:
x (Tensor): Input tensor (time, batch, `*`).
left_context_len: (int): Length of cached left context.
block_size (int):
Returns:
positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
positional embedding, of shape (1, 2*time-1, `*`) or (1, 4*block_size-1, `*`).
"""
self.extend_pe(x, left_context_len)
x_size_left = x.size(0) + left_context_len
# length of positive side: x.size(0) + left_context_len
# length of negative side: x.size(0)
self.extend_pe(x)
rel_pos = 2 * block_size if block_size != 0 else x.size(0)
# length of positive side: 2 * block_size
# length of negative side: 2 * block_size
pos_emb = self.pe[
self.pe.size(0) // 2
- x_size_left
- rel_pos
+ 1 : self.pe.size(0) // 2 # noqa E203
+ x.size(0),
+ rel_pos,
:
]
pos_emb = pos_emb.unsqueeze(0)
@ -1413,6 +1423,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
num_heads: int,
query_head_dim: int,
pos_head_dim: int,
block_size: int,
dropout: float = 0.0,
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5),
(4000.0, 0.0))
@ -1422,6 +1433,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self.num_heads = num_heads
self.query_head_dim = query_head_dim
self.pos_head_dim = pos_head_dim
self.block_size = block_size
self.dropout = dropout
self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
self.name = None # will be overwritten in training code; for diagnostics.
@ -1478,16 +1490,19 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
r"""
Args:
x: input of shape (seq_len, batch_size, embed_dim)
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
pos_emb: Positional embedding tensor, of shape (1, 4*block_size-1, pos_dim)
key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
are True in this mask will be ignored as sources in the attention weighting.
attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
saying which positions are allowed to attend to which other positions.
Returns:
a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len)
interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
a tensor of attention weights, of shape (hum_heads, batch_size * num_blocks, block_size, block_size * 3)
interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
where num_blocks = (seq_len + block_size - 1) // block_size.
"""
assert attn_mask is None, "Not supported yet"
x = self.in_proj(x)
query_head_dim = self.query_head_dim
pos_head_dim = self.pos_head_dim
@ -1508,16 +1523,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
p = self.copy_pos_query(p) # for diagnostics only, does nothing.
q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
# divide into blocks by unfold function
block_size = self.block_size
num_blocks = (seq_len + block_size - 1) // block_size
pad_len = num_blocks * block_size - seq_len
# (kernel, batch_size * num_blocks, channel)
q_blocks = unfold(q, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
p_blocks = unfold(p, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
# time1 refers to target, time2 refers to source.
q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
time1 = q_blocks.size(0)
time2 = k_blocks.size(0)
new_batch_size = batch_size * num_blocks
attn_scores = torch.matmul(q, k)
q_blocks = q_blocks.reshape(time1, new_batch_size, num_heads, query_head_dim)
p_blocks = p_blocks.reshape(time1, new_batch_size, num_heads, pos_head_dim)
k_blocks = k_blocks.reshape(time2, new_batch_size, num_heads, query_head_dim)
q_blocks = q_blocks.permute(2, 1, 0, 3) # (head, new_batch, time1, query_head_dim)
p_blocks = p_blocks.permute(2, 1, 0, 3) # (head, new_batch, time1, pos_head_dim)
k_blocks = k_blocks.permute(2, 1, 3, 0) # (head, new_batch, d_k, time2)
# (head, new_batch, time1, time2)
attn_scores = torch.matmul(q_blocks, k_blocks)
use_pos_scores = False
if torch.jit.is_scripting() or torch.jit.is_tracing():
@ -1528,32 +1558,21 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if use_pos_scores:
pos_emb = self.linear_pos(pos_emb)
seq_len2 = 2 * seq_len - 1
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
pos_emb = pos_emb.reshape(1, time1 + time2 - 1, num_heads, pos_head_dim).permute(2, 0, 3, 1)
# pos shape now: (head, 1, pos_dim, time1+time2-1)
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, time1+time2-1) -> (head, batch, time1, time1+time2-1)
# [where seq_len2 represents relative position.]
pos_scores = torch.matmul(p, pos_emb)
pos_scores = torch.matmul(p_blocks, pos_emb)
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
else:
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
pos_scores = pos_scores.as_strided((num_heads, new_batch_size, time1, time2),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
storage_offset=pos_scores.stride(3) * (time1 - 1))
attn_scores = attn_scores + pos_scores
@ -1577,9 +1596,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
penalty=1.0e-04,
name=self.name)
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
assert attn_scores.shape == (num_heads, new_batch_size, time1, time2)
if attn_mask is not None:
# TODO:
assert attn_mask.dtype == torch.bool
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
# all scores zero. It's important that this be large enough that exp(-1000)
@ -1587,12 +1607,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# compares the final weights with zero.
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
if key_padding_mask is not None:
assert key_padding_mask is not None
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(1),
-1000,
)
attn_offsets = (~key_padding_mask).float() # 0 at padding positions
# (seq_len, batch, 1)
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)
# (time2, new_batch_size)
attn_offsets = unfold(
attn_offsets, pad_len, num_blocks,
kernel=block_size * 3, stride=block_size, padding=block_size,
).squeeze(-1)
# For the blocks are all padding
all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0) # (1, new_batch_size)
all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1) # (1, new_batch_size, 1, 1)
attn_offsets = 1 - attn_offsets # 1 at padding positions
# attn_offsets[attn_offsets != 0] = float("-inf")
attn_offsets[attn_offsets != 0] = -1000
# attn_offsets = attn_offsets.masked_fill((attn_offsets != 0), -1000)
# (1, new_batch_size, 1, time2)
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0)
attn_scores = attn_scores + attn_offsets
# We use our own version of softmax, defined in scaling.py, which should
# save a little of the memory used in backprop by, if we are in
@ -1600,6 +1639,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# half-precision output for backprop purposes.
attn_weights = softmax(attn_scores, dim=-1)
# For the blocks are all padding
attn_weights = attn_weights.masked_fill(all_pad_mask, 0.0)
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif random.random() < 0.001 and not self.training:
@ -1743,8 +1785,10 @@ class SelfAttention(nn.Module):
embed_dim: int,
num_heads: int,
value_head_dim: int,
block_size: int,
) -> None:
super().__init__()
self.block_size = block_size
self.in_proj = nn.Linear(embed_dim,
num_heads * value_head_dim,
bias=True)
@ -1766,27 +1810,45 @@ class SelfAttention(nn.Module):
"""
Args:
x: input tensor, of shape (seq_len, batch_size, embed_dim)
attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
attn_weights.sum(dim=-1) == 1.
attn_weights: a tensor of attention weights, of shape
(hum_heads, batch_size * num_blocks, block_size, block_size * 3)
interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
where num_blocks = (seq_len + block_size - 1) // block_size.
Expect attn_weights.sum(dim=-1) == 1.
Returns:
a tensor with the same shape as x.
"""
(seq_len, batch_size, embed_dim) = x.shape
num_heads = attn_weights.shape[0]
assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
# divide into blocks by unfold function
block_size = self.block_size
num_blocks = (seq_len + block_size - 1) // block_size
pad_len = num_blocks * block_size - seq_len
new_batch_size = batch_size * num_blocks
time1 = block_size # target length
time2 = 3 * block_size # source length
assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)
x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, batch_size, seq_len, value_head_dim)
value_head_dim = x.shape[-1]
# (time2, new_batch_size, channel)
x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, new_batch_size, time2, value_head_dim)
value_head_dim = x_blocks.shape[-1]
# todo: see whether there is benefit in overriding matmul
x = torch.matmul(attn_weights, x)
# v: (num_heads, batch_size, seq_len, value_head_dim)
x = torch.matmul(attn_weights, x_blocks)
# v: (num_heads, new_batch_size, time1, value_head_dim)
x = x.permute(2, 1, 0, 3).contiguous().view(
seq_len, batch_size, num_heads * value_head_dim)
x = x.reshape(num_heads, batch_size, num_blocks, time1, value_head_dim)
x = x.permute(2, 3, 1, 0, 4).contiguous().view(
num_blocks * time1, batch_size, num_heads * value_head_dim)
x = x[:seq_len] # (seq_len, batch_size, value_dim)
# returned value is of shape (seq_len, batch_size, embed_dim), like the input.
x = self.out_proj(x)
@ -1896,10 +1958,12 @@ class NonlinAttention(nn.Module):
self,
channels: int,
hidden_channels: int,
block_size: int,
) -> None:
super().__init__()
self.hidden_channels = hidden_channels
self.block_size = block_size
self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
@ -1942,7 +2006,11 @@ class NonlinAttention(nn.Module):
""".
Args:
x: a Tensor of shape (seq_len, batch_size, num_channels)
attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
attn_weights: a tensor of attention weights, of shape
(hum_heads, batch_size * num_blocks, block_size, block_size * 3)
interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
where num_blocks = (seq_len + block_size - 1) // block_size.
Expect attn_weights.sum(dim=-1) == 1.
Returns:
a Tensor with the same shape as x
"""
@ -1965,13 +2033,31 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
(seq_len, batch_size, embed_dim) = x.shape
num_heads = attn_weights.shape[0]
assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, batch_size, seq_len, head_dim)
x = torch.matmul(attn_weights, x)
# now x: (num_heads, batch_size, seq_len, head_dim)
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
# divide into blocks by unfold function
block_size = self.block_size
num_blocks = (seq_len + block_size - 1) // block_size
pad_len = num_blocks * block_size - seq_len
new_batch_size = batch_size * num_blocks
time1 = block_size # target length
time2 = 3 * block_size # source length
assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)
# (time2, new_batch_size, channel)
x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size)
x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, new_batch_size, time2, head_dim)
x = torch.matmul(attn_weights, x_blocks)
# now x: (num_heads, new_batch_size, time1, head_dim)
x = x.reshape(num_heads, batch_size, num_blocks, time1, -1)
x = x.permute(2, 3, 1, 0, 4).contiguous().view(
num_blocks * time1, batch_size, embed_dim)
x = x[:seq_len] # (seq_len, batch_size, embed_dim)
y = self.identity2(y)
x = x * y
@ -2220,30 +2306,37 @@ class ScalarMultiply(nn.Module):
def _test_zipformer_main(causal: bool = False):
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
from icefall.utils import make_pad_mask
c = Zipformer2(
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
downsampling_factor=(1, 2),
block_size=4,
causal=causal,
chunk_size=(4,) if causal else (-1,),
left_context_frames=(64,)
)
batch_size = 5
seq_len = 20
batch_size = 2
seq_len = 14
# Just make sure the forward pass runs.
f = c(
torch.randn(seq_len, batch_size, 64),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
x = torch.randn(seq_len, batch_size, 64)
lengths = torch.full((batch_size,), seq_len, dtype=torch.int64)
lengths[-1] = 1
src_key_padding_mask = make_pad_mask(lengths)
f = c(x, lengths, src_key_padding_mask)
f[0].sum().backward()
c.eval()
f = c(
torch.randn(seq_len, batch_size, 64),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
x = torch.randn(seq_len, batch_size, 64)
lengths = torch.full((batch_size,), seq_len, dtype=torch.int64)
lengths[-1] = seq_len - 2
src_key_padding_mask = make_pad_mask(lengths)
f = c(x, lengths, src_key_padding_mask)
f # to remove flake8 warnings
print(f[0].sum())
if __name__ == "__main__":
@ -2251,4 +2344,4 @@ if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_zipformer_main(False)
_test_zipformer_main(True)
# _test_zipformer_main(True)