mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Merge 3dc33515c0bcba749a775cf08b8aba546763fb66 into 3199058194a48d45aeee740f2aa9bdbef0bec29d
This commit is contained in:
commit
9ceffa4db1
@ -1602,6 +1602,61 @@ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
|
|||||||
return torch.cat((x, zeros), dim=-1)
|
return torch.cat((x, zeros), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def unfold(
|
||||||
|
x: Tensor, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: input of shape (seq_len, batch_size, channel)
|
||||||
|
Returns:
|
||||||
|
blocks: (kernel, batch_size * num_blocks, channel)
|
||||||
|
"""
|
||||||
|
seq_len, batch_size, channel = x.size()
|
||||||
|
x = x.permute(1, 2, 0) # (batch_size, channel, seq_len)
|
||||||
|
|
||||||
|
x = nn.functional.pad(x, pad=(0, x_pad), value=0.0)
|
||||||
|
|
||||||
|
blocks = nn.functional.unfold(
|
||||||
|
x.unsqueeze(-1),
|
||||||
|
kernel_size=(kernel, 1),
|
||||||
|
padding=(padding, 0),
|
||||||
|
stride=(stride, 1),
|
||||||
|
) # (B, C * kernel, num_blocks)
|
||||||
|
blocks = blocks.reshape(batch_size, channel, kernel, num_blocks)
|
||||||
|
blocks = blocks.permute(2, 0, 3, 1)
|
||||||
|
blocks = blocks.reshape(kernel, batch_size * num_blocks, channel)
|
||||||
|
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def fold(
|
||||||
|
blocks: Tensor, seq_len: int, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
blocks: (kernel, batch_size * num_blocks, channel)
|
||||||
|
Returns:
|
||||||
|
x: (seq_len, batch_size, channel)
|
||||||
|
"""
|
||||||
|
batch_size = blocks.size(1) // num_blocks
|
||||||
|
channel = blocks.size(2)
|
||||||
|
|
||||||
|
blocks = blocks.reshape(kernel, batch_size, num_blocks, channel)
|
||||||
|
blocks = blocks.permute(1, 3, 0, 2).reshape(batch_size, channel * kernel, num_blocks)
|
||||||
|
|
||||||
|
x = nn.functional.fold(
|
||||||
|
blocks,
|
||||||
|
output_size=(seq_len + x_pad, 1),
|
||||||
|
kernel_size=(kernel, 1),
|
||||||
|
padding=(padding, 0),
|
||||||
|
stride=(stride, 1),
|
||||||
|
)
|
||||||
|
x = x.squeeze(-1).permute(2, 0, 1)
|
||||||
|
x = x[:seq_len] # (seq_len, batch_size, channel)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def _test_whiten():
|
def _test_whiten():
|
||||||
for proportion in [0.1, 0.5, 10.0]:
|
for proportion in [0.1, 0.5, 10.0]:
|
||||||
logging.info(f"_test_whiten(): proportion = {proportion}")
|
logging.info(f"_test_whiten(): proportion = {proportion}")
|
||||||
|
@ -187,6 +187,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
help="Positional-encoding embedding dimension",
|
help="Positional-encoding embedding dimension",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-block-size",
|
||||||
|
type=str,
|
||||||
|
default="512",
|
||||||
|
help="Max block size used in block-wise attention; a single int or comma-separated list",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-unmasked-dim",
|
"--encoder-unmasked-dim",
|
||||||
type=str,
|
type=str,
|
||||||
@ -574,6 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
num_heads=_to_int_tuple(params.num_heads),
|
num_heads=_to_int_tuple(params.num_heads),
|
||||||
feedforward_dim=_to_int_tuple(params.feedforward_dim),
|
feedforward_dim=_to_int_tuple(params.feedforward_dim),
|
||||||
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
|
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
|
||||||
|
max_block_size=_to_int_tuple(params.max_block_size),
|
||||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||||
warmup_batches=4000.0,
|
warmup_batches=4000.0,
|
||||||
causal=params.causal,
|
causal=params.causal,
|
||||||
|
@ -39,6 +39,8 @@ from scaling import (
|
|||||||
FloatLike,
|
FloatLike,
|
||||||
limit_param_value,
|
limit_param_value,
|
||||||
convert_num_channels,
|
convert_num_channels,
|
||||||
|
fold,
|
||||||
|
unfold,
|
||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
@ -105,6 +107,8 @@ class Zipformer2(EncoderInterface):
|
|||||||
feedforward_dim: Union[int, Tuple[int]] = 1536,
|
feedforward_dim: Union[int, Tuple[int]] = 1536,
|
||||||
cnn_module_kernel: Union[int, Tuple[int]] = 31,
|
cnn_module_kernel: Union[int, Tuple[int]] = 31,
|
||||||
pos_dim: int = 192,
|
pos_dim: int = 192,
|
||||||
|
max_block_size: Union[int, Tuple[int]] = 512,
|
||||||
|
block_pad: int = 16,
|
||||||
dropout: FloatLike = None, # see code below for default
|
dropout: FloatLike = None, # see code below for default
|
||||||
warmup_batches: float = 4000.0,
|
warmup_batches: float = 4000.0,
|
||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
@ -140,6 +144,7 @@ class Zipformer2(EncoderInterface):
|
|||||||
self.num_heads = num_heads = _to_tuple(num_heads)
|
self.num_heads = num_heads = _to_tuple(num_heads)
|
||||||
feedforward_dim = _to_tuple(feedforward_dim)
|
feedforward_dim = _to_tuple(feedforward_dim)
|
||||||
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
|
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
|
||||||
|
self.max_block_size = max_block_size = _to_tuple(max_block_size)
|
||||||
|
|
||||||
self.causal = causal
|
self.causal = causal
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
@ -153,6 +158,7 @@ class Zipformer2(EncoderInterface):
|
|||||||
|
|
||||||
num_encoders = len(downsampling_factor)
|
num_encoders = len(downsampling_factor)
|
||||||
for i in range(num_encoders):
|
for i in range(num_encoders):
|
||||||
|
ds = downsampling_factor[i]
|
||||||
|
|
||||||
encoder_layer = Zipformer2EncoderLayer(
|
encoder_layer = Zipformer2EncoderLayer(
|
||||||
embed_dim=encoder_dim[i],
|
embed_dim=encoder_dim[i],
|
||||||
@ -173,13 +179,15 @@ class Zipformer2(EncoderInterface):
|
|||||||
encoder_layer,
|
encoder_layer,
|
||||||
num_encoder_layers[i],
|
num_encoder_layers[i],
|
||||||
pos_dim=pos_dim,
|
pos_dim=pos_dim,
|
||||||
|
max_block_size=max_block_size[i],
|
||||||
|
block_pad=block_pad,
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
|
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
|
||||||
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
||||||
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
|
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
|
||||||
)
|
)
|
||||||
|
|
||||||
if downsampling_factor[i] != 1:
|
if ds != 1:
|
||||||
encoder = DownsampledZipformer2Encoder(
|
encoder = DownsampledZipformer2Encoder(
|
||||||
encoder,
|
encoder,
|
||||||
dim=encoder_dim[i],
|
dim=encoder_dim[i],
|
||||||
@ -674,6 +682,8 @@ class Zipformer2EncoderLayer(nn.Module):
|
|||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
chunk_size: int = -1,
|
chunk_size: int = -1,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
attn_offsets: Optional[Tensor] = None,
|
||||||
|
all_pad_mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""
|
||||||
@ -681,6 +691,8 @@ class Zipformer2EncoderLayer(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
|
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
|
||||||
pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
|
pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
|
||||||
|
block_size: size of block
|
||||||
|
block_pad: pad size at each side of block
|
||||||
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
|
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
|
||||||
feature_mask: something that broadcasts with src, that we'll multiply `src`
|
feature_mask: something that broadcasts with src, that we'll multiply `src`
|
||||||
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
|
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
|
||||||
@ -706,7 +718,8 @@ class Zipformer2EncoderLayer(nn.Module):
|
|||||||
src,
|
src,
|
||||||
pos_emb=pos_emb,
|
pos_emb=pos_emb,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
key_padding_mask=src_key_padding_mask,
|
attn_offsets=attn_offsets,
|
||||||
|
all_pad_mask=all_pad_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
src = src + self.feed_forward1(src)
|
src = src + self.feed_forward1(src)
|
||||||
@ -725,7 +738,8 @@ class Zipformer2EncoderLayer(nn.Module):
|
|||||||
selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
|
selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
|
||||||
selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
|
selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
|
||||||
|
|
||||||
na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
|
na = self.nonlin_attention(src, selected_attn_weights)
|
||||||
|
na = self.balancer_na(na)
|
||||||
|
|
||||||
src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
|
src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
|
||||||
|
|
||||||
@ -917,9 +931,11 @@ class Zipformer2Encoder(nn.Module):
|
|||||||
encoder_layer: nn.Module,
|
encoder_layer: nn.Module,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
pos_dim: int,
|
pos_dim: int,
|
||||||
|
max_block_size: int,
|
||||||
dropout: float,
|
dropout: float,
|
||||||
warmup_begin: float,
|
warmup_begin: float,
|
||||||
warmup_end: float,
|
warmup_end: float,
|
||||||
|
block_pad: int = 16,
|
||||||
initial_layerdrop_rate: float = 0.5,
|
initial_layerdrop_rate: float = 0.5,
|
||||||
final_layerdrop_rate: float = 0.05,
|
final_layerdrop_rate: float = 0.05,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -931,6 +947,8 @@ class Zipformer2Encoder(nn.Module):
|
|||||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||||
)
|
)
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
self.max_block_size = max_block_size
|
||||||
|
self.block_pad = block_pad
|
||||||
|
|
||||||
assert 0 <= warmup_begin <= warmup_end
|
assert 0 <= warmup_begin <= warmup_end
|
||||||
|
|
||||||
@ -957,7 +975,7 @@ class Zipformer2Encoder(nn.Module):
|
|||||||
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
|
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
|
||||||
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
|
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
|
||||||
feature_mask: something that broadcasts with src, that we'll multiply `src`
|
feature_mask: something that broadcasts with src, that we'll multiply `src`
|
||||||
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
|
by at every layer: if a Tensor, likely of shape (1, batch_size, embedding_dim)
|
||||||
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
|
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
|
||||||
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
|
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
|
||||||
True means masked position. May be None.
|
True means masked position. May be None.
|
||||||
@ -966,6 +984,70 @@ class Zipformer2Encoder(nn.Module):
|
|||||||
|
|
||||||
Returns: a Tensor with the same shape as src.
|
Returns: a Tensor with the same shape as src.
|
||||||
"""
|
"""
|
||||||
|
seq_len, batch_size, channel = src.size()
|
||||||
|
max_block_size = self.max_block_size
|
||||||
|
block_pad = self.block_pad
|
||||||
|
|
||||||
|
if seq_len > max_block_size:
|
||||||
|
# divide into blocks with overlaps
|
||||||
|
num_blocks = math.ceil(seq_len / max_block_size)
|
||||||
|
block_size = math.ceil(seq_len / num_blocks)
|
||||||
|
pad_len = num_blocks * block_size - seq_len
|
||||||
|
kernel_size = block_size + 2 * block_pad
|
||||||
|
if random.random() < 0.2 or __name__ == "__main__":
|
||||||
|
logging.info(f"seq_len={seq_len}, block_size={block_size}, pad_len={pad_len}")
|
||||||
|
|
||||||
|
# (block_size + 2 * block_pad, batch_size * num_blocks, channel)
|
||||||
|
src = unfold(
|
||||||
|
src, pad_len, num_blocks,
|
||||||
|
kernel=kernel_size, stride=block_size, padding=block_pad
|
||||||
|
)
|
||||||
|
|
||||||
|
# Used to mask out the padding positions
|
||||||
|
attn_offsets = torch.ones(batch_size, seq_len, device=src.device)
|
||||||
|
|
||||||
|
if src_key_padding_mask is not None:
|
||||||
|
assert src_key_padding_mask.shape == (batch_size, seq_len), src_key_padding_mask.shape
|
||||||
|
attn_offsets = attn_offsets.masked_fill(src_key_padding_mask, 0.0) # 0 at padding positions
|
||||||
|
# (seq_len, batch, 1)
|
||||||
|
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)
|
||||||
|
# (kernel_size, new_batch_size)
|
||||||
|
attn_offsets = unfold(
|
||||||
|
attn_offsets, pad_len, num_blocks,
|
||||||
|
kernel=kernel_size, stride=block_size, padding=block_pad,
|
||||||
|
).squeeze(-1)
|
||||||
|
|
||||||
|
# Used for the blocks are all padding
|
||||||
|
all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0) # (1, new_batch_size)
|
||||||
|
all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1) # (1, new_batch_size, 1, 1)
|
||||||
|
|
||||||
|
# (new_batch_size, kernel_size)
|
||||||
|
src_key_padding_mask = (attn_offsets == 0).transpose(0, 1)
|
||||||
|
|
||||||
|
attn_offsets = 1 - attn_offsets # 1 at padding positions
|
||||||
|
attn_offsets[attn_offsets != 0] = -1000
|
||||||
|
|
||||||
|
# (1, new_batch_size, 1, kernel)
|
||||||
|
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0)
|
||||||
|
|
||||||
|
# feature_mask: (1, batch_size, channel)
|
||||||
|
if isinstance(feature_mask, Tensor):
|
||||||
|
feature_mask = feature_mask.unsqueeze(2).expand(-1, -1, num_blocks, -1)
|
||||||
|
# now (kernel_size, batch_size, num_blocks, channel)
|
||||||
|
feature_mask = feature_mask.reshape(1, batch_size * num_blocks, channel)
|
||||||
|
else:
|
||||||
|
block_size = 0
|
||||||
|
|
||||||
|
# Used to mask out the padding positions
|
||||||
|
attn_offsets = torch.zeros(batch_size, seq_len, device=src.device)
|
||||||
|
if src_key_padding_mask is not None:
|
||||||
|
assert src_key_padding_mask.shape == (batch_size, seq_len), src_key_padding_mask.shape
|
||||||
|
attn_offsets = attn_offsets.masked_fill(src_key_padding_mask, -1000) # 0 at padding positions
|
||||||
|
# (1, batch_size, 1, seq_len)
|
||||||
|
attn_offsets = attn_offsets.unsqueeze(1).unsqueeze(0)
|
||||||
|
|
||||||
|
all_pad_mask = None
|
||||||
|
|
||||||
pos_emb = self.encoder_pos(src)
|
pos_emb = self.encoder_pos(src)
|
||||||
output = src
|
output = src
|
||||||
|
|
||||||
@ -978,12 +1060,29 @@ class Zipformer2Encoder(nn.Module):
|
|||||||
pos_emb,
|
pos_emb,
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
|
attn_offsets=attn_offsets,
|
||||||
|
all_pad_mask=all_pad_mask,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
|
||||||
output = output * feature_mask
|
output = output * feature_mask
|
||||||
|
|
||||||
|
if seq_len > max_block_size:
|
||||||
|
# overlap-and-add
|
||||||
|
output = fold(
|
||||||
|
output, seq_len, pad_len, num_blocks,
|
||||||
|
kernel=kernel_size, stride=block_size, padding=block_pad
|
||||||
|
) # (seq_len, batch_size, channel)
|
||||||
|
mask = torch.ones(
|
||||||
|
kernel_size, batch_size * num_blocks, 1, device=src.device,
|
||||||
|
)
|
||||||
|
mask = fold(
|
||||||
|
mask, seq_len, pad_len, num_blocks,
|
||||||
|
kernel=kernel_size, stride=block_size, padding=block_pad
|
||||||
|
) # (seq_len, batch_size, 1)
|
||||||
|
output = output / mask
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def streaming_forward(
|
def streaming_forward(
|
||||||
@ -1314,9 +1413,9 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
|||||||
self.length_factor = length_factor
|
self.length_factor = length_factor
|
||||||
self.extend_pe(torch.tensor(0.0).expand(max_len))
|
self.extend_pe(torch.tensor(0.0).expand(max_len))
|
||||||
|
|
||||||
def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
|
def extend_pe(self, x: Tensor) -> None:
|
||||||
"""Reset the positional encodings."""
|
"""Reset the positional encodings."""
|
||||||
T = x.size(0) + left_context_len
|
T = x.size(0)
|
||||||
|
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
@ -1361,25 +1460,24 @@ class CompactRelPositionalEncoding(torch.nn.Module):
|
|||||||
|
|
||||||
self.pe = pe.to(dtype=x.dtype)
|
self.pe = pe.to(dtype=x.dtype)
|
||||||
|
|
||||||
def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
|
def forward(self, x: Tensor, rel_pos: int = 0) -> Tensor:
|
||||||
"""Create positional encoding.
|
"""Create positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): Input tensor (time, batch, `*`).
|
x (Tensor): Input tensor (time, batch, `*`).
|
||||||
left_context_len: (int): Length of cached left context.
|
block_size (int):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
|
positional embedding, of shape (1, 2*time-1, `*`) or (1, 4*block_size-1, `*`).
|
||||||
"""
|
"""
|
||||||
self.extend_pe(x, left_context_len)
|
self.extend_pe(x)
|
||||||
x_size_left = x.size(0) + left_context_len
|
if rel_pos == 0:
|
||||||
# length of positive side: x.size(0) + left_context_len
|
rel_pos = x.size(0)
|
||||||
# length of negative side: x.size(0)
|
|
||||||
pos_emb = self.pe[
|
pos_emb = self.pe[
|
||||||
self.pe.size(0) // 2
|
self.pe.size(0) // 2
|
||||||
- x_size_left
|
- rel_pos
|
||||||
+ 1 : self.pe.size(0) // 2 # noqa E203
|
+ 1 : self.pe.size(0) // 2 # noqa E203
|
||||||
+ x.size(0),
|
+ rel_pos,
|
||||||
:
|
:
|
||||||
]
|
]
|
||||||
pos_emb = pos_emb.unsqueeze(0)
|
pos_emb = pos_emb.unsqueeze(0)
|
||||||
@ -1472,7 +1570,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
key_padding_mask: Optional[Tensor] = None,
|
attn_offsets: Optional[Tensor] = None,
|
||||||
|
all_pad_mask: Optional[Tensor] = None,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
r"""
|
r"""
|
||||||
@ -1580,6 +1679,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
|
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
|
||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
|
assert attn_mask is None
|
||||||
assert attn_mask.dtype == torch.bool
|
assert attn_mask.dtype == torch.bool
|
||||||
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
|
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
|
||||||
# all scores zero. It's important that this be large enough that exp(-1000)
|
# all scores zero. It's important that this be large enough that exp(-1000)
|
||||||
@ -1587,12 +1687,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
# compares the final weights with zero.
|
# compares the final weights with zero.
|
||||||
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
|
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
|
||||||
|
|
||||||
if key_padding_mask is not None:
|
if attn_offsets is not None:
|
||||||
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
|
# attn_offsets: (1, batch_size, 1, seq_len)
|
||||||
attn_scores = attn_scores.masked_fill(
|
# or (1, new_batch_size, 1, kernel)
|
||||||
key_padding_mask.unsqueeze(1),
|
attn_scores = attn_scores + attn_offsets
|
||||||
-1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
# We use our own version of softmax, defined in scaling.py, which should
|
# We use our own version of softmax, defined in scaling.py, which should
|
||||||
# save a little of the memory used in backprop by, if we are in
|
# save a little of the memory used in backprop by, if we are in
|
||||||
@ -1600,6 +1698,189 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
# half-precision output for backprop purposes.
|
# half-precision output for backprop purposes.
|
||||||
attn_weights = softmax(attn_scores, dim=-1)
|
attn_weights = softmax(attn_scores, dim=-1)
|
||||||
|
|
||||||
|
if all_pad_mask is not None:
|
||||||
|
# For the blocks are all padding
|
||||||
|
# all_pad_mask: (1, new_batch_size, 1, 1)
|
||||||
|
attn_weights = attn_weights.masked_fill(all_pad_mask, 0.0)
|
||||||
|
|
||||||
|
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||||
|
pass
|
||||||
|
elif random.random() < 0.001 and not self.training:
|
||||||
|
self._print_attn_entropy(attn_weights)
|
||||||
|
|
||||||
|
attn_weights = nn.functional.dropout(
|
||||||
|
attn_weights, p=self.dropout, training=self.training
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_weights
|
||||||
|
|
||||||
|
def forward_block(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
block_size: int,
|
||||||
|
block_pad: int,
|
||||||
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
x: input of shape (seq_len, batch_size, embed_dim)
|
||||||
|
pos_emb: Positional embedding tensor, of shape (1, 4*block_size-1, pos_dim)
|
||||||
|
block_size: size of block
|
||||||
|
block_pad: pad size at each side of block
|
||||||
|
key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
|
||||||
|
are True in this mask will be ignored as sources in the attention weighting.
|
||||||
|
attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
|
||||||
|
interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
|
||||||
|
saying which positions are allowed to attend to which other positions.
|
||||||
|
Returns:
|
||||||
|
a tensor of attention weights, of shape (hum_heads, batch_size * num_blocks, block_size, block_size * 3)
|
||||||
|
interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
|
||||||
|
where num_blocks = (seq_len + block_size - 1) // block_size.
|
||||||
|
"""
|
||||||
|
assert attn_mask is None, "Not supported yet"
|
||||||
|
|
||||||
|
x = self.in_proj(x)
|
||||||
|
query_head_dim = self.query_head_dim
|
||||||
|
pos_head_dim = self.pos_head_dim
|
||||||
|
num_heads = self.num_heads
|
||||||
|
|
||||||
|
seq_len, batch_size, _ = x.shape
|
||||||
|
|
||||||
|
query_dim = query_head_dim * num_heads
|
||||||
|
|
||||||
|
# self-attention
|
||||||
|
q = x[...,0:query_dim]
|
||||||
|
k = x[...,query_dim:2*query_dim]
|
||||||
|
# p is the position-encoding query
|
||||||
|
p = x[...,2*query_dim:]
|
||||||
|
assert p.shape[-1] == num_heads * pos_head_dim
|
||||||
|
|
||||||
|
q = self.copy_query(q) # for diagnostics only, does nothing.
|
||||||
|
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
|
||||||
|
p = self.copy_pos_query(p) # for diagnostics only, does nothing.
|
||||||
|
|
||||||
|
# divide into blocks by unfold function
|
||||||
|
num_blocks = (seq_len + block_size - 1) // block_size
|
||||||
|
pad_len = num_blocks * block_size - seq_len
|
||||||
|
|
||||||
|
# (kernel, batch_size * num_blocks, channel)
|
||||||
|
q_blocks = unfold(q, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
|
||||||
|
p_blocks = unfold(p, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0)
|
||||||
|
k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size + 2 * block_pad, stride=block_size, padding=block_pad)
|
||||||
|
|
||||||
|
# time1 refers to target, time2 refers to source.
|
||||||
|
time1 = q_blocks.size(0)
|
||||||
|
time2 = k_blocks.size(0)
|
||||||
|
new_batch_size = batch_size * num_blocks
|
||||||
|
|
||||||
|
q_blocks = q_blocks.reshape(time1, new_batch_size, num_heads, query_head_dim)
|
||||||
|
p_blocks = p_blocks.reshape(time1, new_batch_size, num_heads, pos_head_dim)
|
||||||
|
k_blocks = k_blocks.reshape(time2, new_batch_size, num_heads, query_head_dim)
|
||||||
|
|
||||||
|
q_blocks = q_blocks.permute(2, 1, 0, 3) # (head, new_batch, time1, query_head_dim)
|
||||||
|
p_blocks = p_blocks.permute(2, 1, 0, 3) # (head, new_batch, time1, pos_head_dim)
|
||||||
|
k_blocks = k_blocks.permute(2, 1, 3, 0) # (head, new_batch, d_k, time2)
|
||||||
|
|
||||||
|
# (head, new_batch, time1, time2)
|
||||||
|
attn_scores = torch.matmul(q_blocks, k_blocks)
|
||||||
|
|
||||||
|
use_pos_scores = False
|
||||||
|
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||||
|
# We can't put random.random() in the same line
|
||||||
|
use_pos_scores = True
|
||||||
|
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
||||||
|
use_pos_scores = True
|
||||||
|
|
||||||
|
if use_pos_scores:
|
||||||
|
pos_emb = self.linear_pos(pos_emb)
|
||||||
|
pos_emb = pos_emb.reshape(1, time1 + time2 - 1, num_heads, pos_head_dim).permute(2, 0, 3, 1)
|
||||||
|
# pos shape now: (head, 1, pos_dim, time1+time2-1)
|
||||||
|
|
||||||
|
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, time1+time2-1) -> (head, batch, time1, time1+time2-1)
|
||||||
|
# [where seq_len2 represents relative position.]
|
||||||
|
pos_scores = torch.matmul(p_blocks, pos_emb)
|
||||||
|
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
||||||
|
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||||
|
# not, but let this code define which way round it is supposed to be.
|
||||||
|
pos_scores = pos_scores.as_strided((num_heads, new_batch_size, time1, time2),
|
||||||
|
(pos_scores.stride(0),
|
||||||
|
pos_scores.stride(1),
|
||||||
|
pos_scores.stride(2)-pos_scores.stride(3),
|
||||||
|
pos_scores.stride(3)),
|
||||||
|
storage_offset=pos_scores.stride(3) * (time1 - 1))
|
||||||
|
|
||||||
|
attn_scores = attn_scores + pos_scores
|
||||||
|
|
||||||
|
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||||
|
pass
|
||||||
|
elif self.training and random.random() < 0.1:
|
||||||
|
# This is a harder way of limiting the attention scores to not be
|
||||||
|
# too large. It incurs a penalty if any of them has an absolute
|
||||||
|
# value greater than 50.0. this should be outside the normal range
|
||||||
|
# of the attention scores. We use this mechanism instead of, say,
|
||||||
|
# something added to the loss function involving the entropy,
|
||||||
|
# because once the entropy gets very small gradients through the
|
||||||
|
# softmax can become very small, and we'd get zero derivatives. The
|
||||||
|
# choices of 1.0e-04 as the scale on the penalty makes this
|
||||||
|
# mechanism vulnerable to the absolute scale of the loss function,
|
||||||
|
# but we view this as a failsafe to avoid "implausible" parameter
|
||||||
|
# values rather than a regularization method that should be active
|
||||||
|
# under normal circumstances.
|
||||||
|
attn_scores = penalize_abs_values_gt(attn_scores,
|
||||||
|
limit=25.0,
|
||||||
|
penalty=1.0e-04,
|
||||||
|
name=self.name)
|
||||||
|
|
||||||
|
assert attn_scores.shape == (num_heads, new_batch_size, time1, time2)
|
||||||
|
|
||||||
|
assert attn_mask is None
|
||||||
|
if attn_mask is not None:
|
||||||
|
# TODO:
|
||||||
|
assert attn_mask.dtype == torch.bool
|
||||||
|
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
|
||||||
|
# all scores zero. It's important that this be large enough that exp(-1000)
|
||||||
|
# is exactly zero, for reasons related to const_attention_rate, it
|
||||||
|
# compares the final weights with zero.
|
||||||
|
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
|
||||||
|
|
||||||
|
# Used to mask out the padding positions
|
||||||
|
attn_offsets = torch.ones(batch_size, seq_len, device=x.device)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
|
||||||
|
attn_offsets = attn_offsets.masked_fill(key_padding_mask, 0.0) # 0 at padding positions
|
||||||
|
|
||||||
|
# (seq_len, batch, 1)
|
||||||
|
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)
|
||||||
|
# (time2, new_batch_size)
|
||||||
|
attn_offsets = unfold(
|
||||||
|
attn_offsets, pad_len, num_blocks,
|
||||||
|
kernel=block_size + 2 * block_pad, stride=block_size, padding=block_pad,
|
||||||
|
).squeeze(-1)
|
||||||
|
|
||||||
|
# Used for the blocks are all padding
|
||||||
|
all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0) # (1, new_batch_size)
|
||||||
|
all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1) # (1, new_batch_size, 1, 1)
|
||||||
|
|
||||||
|
attn_offsets = 1 - attn_offsets # 1 at padding positions
|
||||||
|
attn_offsets[attn_offsets != 0] = -1000
|
||||||
|
|
||||||
|
# (1, new_batch_size, 1, time2)
|
||||||
|
attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0)
|
||||||
|
|
||||||
|
attn_scores = attn_scores + attn_offsets
|
||||||
|
|
||||||
|
# We use our own version of softmax, defined in scaling.py, which should
|
||||||
|
# save a little of the memory used in backprop by, if we are in
|
||||||
|
# automatic mixed precision mode (amp / autocast), by only storing the
|
||||||
|
# half-precision output for backprop purposes.
|
||||||
|
attn_weights = softmax(attn_scores, dim=-1)
|
||||||
|
|
||||||
|
# For the blocks are all padding
|
||||||
|
attn_weights = attn_weights.masked_fill(all_pad_mask, 0.0)
|
||||||
|
|
||||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||||
pass
|
pass
|
||||||
elif random.random() < 0.001 and not self.training:
|
elif random.random() < 0.001 and not self.training:
|
||||||
@ -1678,7 +1959,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
||||||
# [where seq_len2 represents relative position.]
|
# [where seq_len2 represents relative position.]
|
||||||
pos_scores = torch.matmul(p, pos_emb)
|
pos_scores = torch.matmul(p, pos_emb)
|
||||||
|
|
||||||
if torch.jit.is_tracing():
|
if torch.jit.is_tracing():
|
||||||
(num_heads, batch_size, time1, n) = pos_scores.shape
|
(num_heads, batch_size, time1, n) = pos_scores.shape
|
||||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||||
@ -1794,6 +2075,63 @@ class SelfAttention(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def forward_block(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
attn_weights: Tensor,
|
||||||
|
block_size: int,
|
||||||
|
block_pad: int,
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: input tensor, of shape (seq_len, batch_size, embed_dim)
|
||||||
|
attn_weights: a tensor of attention weights, of shape
|
||||||
|
(hum_heads, batch_size * num_blocks, block_size, block_size * 3)
|
||||||
|
interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
|
||||||
|
where num_blocks = (seq_len + block_size - 1) // block_size.
|
||||||
|
Expect attn_weights.sum(dim=-1) == 1.
|
||||||
|
block_size: size of block
|
||||||
|
block_pad: pad size at each side of block
|
||||||
|
Returns:
|
||||||
|
a tensor with the same shape as x.
|
||||||
|
"""
|
||||||
|
(seq_len, batch_size, embed_dim) = x.shape
|
||||||
|
num_heads = attn_weights.shape[0]
|
||||||
|
|
||||||
|
# divide into blocks by unfold function
|
||||||
|
num_blocks = (seq_len + block_size - 1) // block_size
|
||||||
|
pad_len = num_blocks * block_size - seq_len
|
||||||
|
new_batch_size = batch_size * num_blocks
|
||||||
|
time1 = block_size # target length
|
||||||
|
time2 = block_size + 2 * block_pad # source length
|
||||||
|
|
||||||
|
assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)
|
||||||
|
|
||||||
|
x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
|
||||||
|
|
||||||
|
# (time2, new_batch_size, channel)
|
||||||
|
x_blocks = unfold(x, pad_len, num_blocks, kernel=time2, stride=block_size, padding=block_pad)
|
||||||
|
|
||||||
|
x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
|
||||||
|
# now x: (num_heads, new_batch_size, time2, value_head_dim)
|
||||||
|
value_head_dim = x_blocks.shape[-1]
|
||||||
|
|
||||||
|
# todo: see whether there is benefit in overriding matmul
|
||||||
|
x = torch.matmul(attn_weights, x_blocks)
|
||||||
|
# v: (num_heads, new_batch_size, time1, value_head_dim)
|
||||||
|
|
||||||
|
x = x.reshape(num_heads, batch_size, num_blocks, time1, value_head_dim)
|
||||||
|
x = x.permute(2, 3, 1, 0, 4).contiguous().view(
|
||||||
|
num_blocks * time1, batch_size, num_heads * value_head_dim)
|
||||||
|
|
||||||
|
x = x[:seq_len] # (seq_len, batch_size, value_dim)
|
||||||
|
|
||||||
|
# returned value is of shape (seq_len, batch_size, embed_dim), like the input.
|
||||||
|
x = self.out_proj(x)
|
||||||
|
x = self.whiten(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
def streaming_forward(
|
def streaming_forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@ -1981,6 +2319,78 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
|||||||
x = self.whiten2(x)
|
x = self.whiten2(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def forward_block(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
attn_weights: Tensor,
|
||||||
|
block_size: int,
|
||||||
|
block_pad: int,
|
||||||
|
) -> Tensor:
|
||||||
|
""".
|
||||||
|
Args:
|
||||||
|
x: a Tensor of shape (seq_len, batch_size, num_channels)
|
||||||
|
attn_weights: a tensor of attention weights, of shape
|
||||||
|
(hum_heads, batch_size * num_blocks, block_size, block_size * 3)
|
||||||
|
interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len),
|
||||||
|
where num_blocks = (seq_len + block_size - 1) // block_size.
|
||||||
|
Expect attn_weights.sum(dim=-1) == 1.
|
||||||
|
block_size: size of block
|
||||||
|
block_pad: pad size at each side of block
|
||||||
|
Returns:
|
||||||
|
a Tensor with the same shape as x
|
||||||
|
"""
|
||||||
|
x = self.in_proj(x)
|
||||||
|
|
||||||
|
(seq_len, batch_size, _) = x.shape
|
||||||
|
hidden_channels = self.hidden_channels
|
||||||
|
|
||||||
|
s, x, y = x.chunk(3, dim=-1)
|
||||||
|
|
||||||
|
# s will go through tanh.
|
||||||
|
|
||||||
|
s = self.balancer(s)
|
||||||
|
s = self.tanh(s)
|
||||||
|
|
||||||
|
s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
|
||||||
|
x = self.whiten1(x)
|
||||||
|
x = x * s
|
||||||
|
x = self.identity1(x) # diagnostics only, it's the identity.
|
||||||
|
|
||||||
|
(seq_len, batch_size, embed_dim) = x.shape
|
||||||
|
num_heads = attn_weights.shape[0]
|
||||||
|
|
||||||
|
# divide into blocks by unfold function
|
||||||
|
num_blocks = (seq_len + block_size - 1) // block_size
|
||||||
|
pad_len = num_blocks * block_size - seq_len
|
||||||
|
new_batch_size = batch_size * num_blocks
|
||||||
|
time1 = block_size # target length
|
||||||
|
time2 = block_size + 2 * block_pad # source length
|
||||||
|
|
||||||
|
assert attn_weights.shape == (num_heads, new_batch_size, time1, time2)
|
||||||
|
|
||||||
|
# (time2, new_batch_size, channel)
|
||||||
|
x_blocks = unfold(x, pad_len, num_blocks, kernel=time2, stride=block_size, padding=block_pad)
|
||||||
|
|
||||||
|
x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3)
|
||||||
|
# now x: (num_heads, new_batch_size, time2, head_dim)
|
||||||
|
|
||||||
|
x = torch.matmul(attn_weights, x_blocks)
|
||||||
|
# now x: (num_heads, new_batch_size, time1, head_dim)
|
||||||
|
|
||||||
|
x = x.reshape(num_heads, batch_size, num_blocks, time1, -1)
|
||||||
|
x = x.permute(2, 3, 1, 0, 4).contiguous().view(
|
||||||
|
num_blocks * time1, batch_size, embed_dim)
|
||||||
|
|
||||||
|
x = x[:seq_len] # (seq_len, batch_size, embed_dim)
|
||||||
|
|
||||||
|
y = self.identity2(y)
|
||||||
|
x = x * y
|
||||||
|
x = self.identity3(x)
|
||||||
|
|
||||||
|
x = self.out_proj(x)
|
||||||
|
x = self.whiten2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
def streaming_forward(
|
def streaming_forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@ -2220,30 +2630,38 @@ class ScalarMultiply(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def _test_zipformer_main(causal: bool = False):
|
def _test_zipformer_main(causal: bool = False):
|
||||||
batch_size = 5
|
|
||||||
seq_len = 20
|
|
||||||
# Just make sure the forward pass runs.
|
# Just make sure the forward pass runs.
|
||||||
|
|
||||||
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
c = Zipformer2(
|
c = Zipformer2(
|
||||||
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
|
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
|
||||||
|
downsampling_factor=(1, 2),
|
||||||
|
max_block_size=14,
|
||||||
|
block_pad=2,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
chunk_size=(4,) if causal else (-1,),
|
chunk_size=(4,) if causal else (-1,),
|
||||||
left_context_frames=(64,)
|
left_context_frames=(64,)
|
||||||
)
|
)
|
||||||
batch_size = 5
|
batch_size = 2
|
||||||
seq_len = 20
|
seq_len = 27
|
||||||
|
|
||||||
# Just make sure the forward pass runs.
|
# Just make sure the forward pass runs.
|
||||||
f = c(
|
x = torch.randn(seq_len, batch_size, 64)
|
||||||
torch.randn(seq_len, batch_size, 64),
|
lengths = torch.full((batch_size,), seq_len, dtype=torch.int64)
|
||||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
lengths[-1] = 1
|
||||||
)
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
|
f = c(x, lengths, src_key_padding_mask)
|
||||||
f[0].sum().backward()
|
f[0].sum().backward()
|
||||||
c.eval()
|
c.eval()
|
||||||
f = c(
|
|
||||||
torch.randn(seq_len, batch_size, 64),
|
x = torch.randn(seq_len, batch_size, 64)
|
||||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
lengths = torch.full((batch_size,), seq_len, dtype=torch.int64)
|
||||||
)
|
lengths[-1] = seq_len - 2
|
||||||
|
src_key_padding_mask = make_pad_mask(lengths)
|
||||||
|
f = c(x, lengths, src_key_padding_mask)
|
||||||
f # to remove flake8 warnings
|
f # to remove flake8 warnings
|
||||||
|
print(f[0].sum())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -2251,4 +2669,4 @@ if __name__ == "__main__":
|
|||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
_test_zipformer_main(False)
|
_test_zipformer_main(False)
|
||||||
_test_zipformer_main(True)
|
# _test_zipformer_main(True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user