From 80a14f93d3d7e3de50c3623cc055333128bb751e Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 20 Jul 2023 19:38:03 +0800 Subject: [PATCH 1/5] Use block-wise attention --- egs/librispeech/ASR/zipformer/scaling.py | 27 +++ egs/librispeech/ASR/zipformer/train.py | 8 + egs/librispeech/ASR/zipformer/zipformer.py | 267 ++++++++++++++------- 3 files changed, 215 insertions(+), 87 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 4ee7b7826..6c315b8e8 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1602,6 +1602,33 @@ def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: return torch.cat((x, zeros), dim=-1) +def unfold( + x: Tensor, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int +) -> Tensor: + """ + Args: + x: input of shape (seq_len, batch_size, channel) + Returns: + blocks: (kernel, batch_size * num_blocks, channel) + """ + seq_len, batch_size, channel = x.size() + x = x.permute(1, 2, 0) # (B, D, T) + + x = nn.functional.pad(x, pad=(0, x_pad), value=0.0) + + blocks = nn.functional.unfold( + x.unsqueeze(-1), + kernel_size=(kernel, 1), + padding=(padding, 0), + stride=(stride, 1), + ) # (B, C * kernel, num_blocks) + blocks = blocks.reshape(batch_size, channel, kernel, num_blocks) + blocks = blocks.permute(2, 0, 3, 1) + blocks = blocks.reshape(kernel, batch_size * num_blocks, channel) + + return blocks + + def _test_whiten(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"_test_whiten(): proportion = {proportion}") diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index bc3e9c1ba..5a8dae619 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -187,6 +187,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Positional-encoding embedding dimension", ) + parser.add_argument( + "--block-size", + type=int, + default="32", + help="Block size used in block-wise attention", + ) + parser.add_argument( "--encoder-unmasked-dim", type=str, @@ -574,6 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_heads=_to_int_tuple(params.num_heads), feedforward_dim=_to_int_tuple(params.feedforward_dim), cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + block_size=params.block_size, dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), warmup_batches=4000.0, causal=params.causal, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7d98dbeb1..d12b9f22b 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -39,6 +39,7 @@ from scaling import ( FloatLike, limit_param_value, convert_num_channels, + unfold, ) from torch import Tensor, nn @@ -105,6 +106,7 @@ class Zipformer2(EncoderInterface): feedforward_dim: Union[int, Tuple[int]] = 1536, cnn_module_kernel: Union[int, Tuple[int]] = 31, pos_dim: int = 192, + block_size: int = 32, dropout: FloatLike = None, # see code below for default warmup_batches: float = 4000.0, causal: bool = False, @@ -140,6 +142,7 @@ class Zipformer2(EncoderInterface): self.num_heads = num_heads = _to_tuple(num_heads) feedforward_dim = _to_tuple(feedforward_dim) self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + self.block_size = block_size self.causal = causal self.chunk_size = chunk_size @@ -153,6 +156,7 @@ class Zipformer2(EncoderInterface): num_encoders = 
len(downsampling_factor) for i in range(num_encoders): + ds = downsampling_factor[i] encoder_layer = Zipformer2EncoderLayer( embed_dim=encoder_dim[i], @@ -164,6 +168,7 @@ class Zipformer2(EncoderInterface): feedforward_dim=feedforward_dim[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], + block_size=block_size // ds, causal=causal, ) @@ -173,13 +178,14 @@ class Zipformer2(EncoderInterface): encoder_layer, num_encoder_layers[i], pos_dim=pos_dim, + block_size=block_size // ds, dropout=dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), ) - if downsampling_factor[i] != 1: + if ds != 1: encoder = DownsampledZipformer2Encoder( encoder, dim=encoder_dim[i], @@ -536,6 +542,7 @@ class Zipformer2EncoderLayer(nn.Module): feedforward_dim: int, dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, + block_size: int = 32, causal: bool = False, attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), @@ -569,14 +576,14 @@ class Zipformer2EncoderLayer(nn.Module): self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, pos_dim=pos_dim, num_heads=num_heads, query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, - dropout=0.0, + block_size=block_size, dropout=0.0, ) self.self_attn1 = SelfAttention(embed_dim, num_heads, - value_head_dim) + value_head_dim, block_size=block_size) self.self_attn2 = SelfAttention(embed_dim, num_heads, - value_head_dim) + value_head_dim, block_size=block_size) self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, @@ -591,7 +598,8 @@ class Zipformer2EncoderLayer(nn.Module): dropout) self.nonlin_attention = NonlinAttention(embed_dim, - hidden_channels=3 * embed_dim // 4) + hidden_channels=3 * embed_dim // 4, + block_size=block_size) self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, @@ -917,6 +925,7 @@ class Zipformer2Encoder(nn.Module): encoder_layer: nn.Module, num_layers: int, pos_dim: int, + block_size: int, dropout: float, warmup_begin: float, warmup_end: float, @@ -931,6 +940,7 @@ class Zipformer2Encoder(nn.Module): [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) self.num_layers = num_layers + self.block_size = block_size assert 0 <= warmup_begin <= warmup_end @@ -966,7 +976,7 @@ class Zipformer2Encoder(nn.Module): Returns: a Tensor with the same shape as src. """ - pos_emb = self.encoder_pos(src) + pos_emb = self.encoder_pos(src, block_size=self.block_size) output = src if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -1314,9 +1324,9 @@ class CompactRelPositionalEncoding(torch.nn.Module): self.length_factor = length_factor self.extend_pe(torch.tensor(0.0).expand(max_len)) - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: + def extend_pe(self, x: Tensor) -> None: """Reset the positional encodings.""" - T = x.size(0) + left_context_len + T = x.size(0) if self.pe is not None: # self.pe contains both positive and negative parts @@ -1361,25 +1371,25 @@ class CompactRelPositionalEncoding(torch.nn.Module): self.pe = pe.to(dtype=x.dtype) - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + def forward(self, x: Tensor, block_size: int = 0) -> Tensor: """Create positional encoding. Args: x (Tensor): Input tensor (time, batch, `*`). 
- left_context_len: (int): Length of cached left context. + block_size (int): Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + positional embedding, of shape (1, 2*time-1, `*`) or (1, 4*block_size-1, `*`). """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) + self.extend_pe(x) + rel_pos = 2 * block_size if block_size != 0 else x.size(0) + # length of positive side: 2 * block_size + # length of negative side: 2 * block_size pos_emb = self.pe[ self.pe.size(0) // 2 - - x_size_left + - rel_pos + 1 : self.pe.size(0) // 2 # noqa E203 - + x.size(0), + + rel_pos, : ] pos_emb = pos_emb.unsqueeze(0) @@ -1413,6 +1423,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): num_heads: int, query_head_dim: int, pos_head_dim: int, + block_size: int, dropout: float = 0.0, pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)) @@ -1422,6 +1433,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): self.num_heads = num_heads self.query_head_dim = query_head_dim self.pos_head_dim = pos_head_dim + self.block_size = block_size self.dropout = dropout self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) self.name = None # will be overwritten in training code; for diagnostics. @@ -1478,16 +1490,19 @@ class RelPositionMultiheadAttentionWeights(nn.Module): r""" Args: x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + pos_emb: Positional embedding tensor, of shape (1, 4*block_size-1, pos_dim) key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that are True in this mask will be ignored as sources in the attention weighting. attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), interpreted as ([batch_size,] tgt_seq_len, src_seq_len) saying which positions are allowed to attend to which other positions. Returns: - a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + a tensor of attention weights, of shape (hum_heads, batch_size * num_blocks, block_size, block_size * 3) + interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len), + where num_blocks = (seq_len + block_size - 1) // block_size. """ + assert attn_mask is None, "Not supported yet" + x = self.in_proj(x) query_head_dim = self.query_head_dim pos_head_dim = self.pos_head_dim @@ -1508,16 +1523,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module): k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. p = self.copy_pos_query(p) # for diagnostics only, does nothing. 
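# A minimal shape sketch (toy numbers, assuming the unfold() helper added to scaling.py in
# this patch) of the block division performed below: with block_size=32, seq_len=100,
# batch_size=B and channel=C,
#   num_blocks = (100 + 32 - 1) // 32 = 4,   pad_len = 4 * 32 - 100 = 28
#   q_blocks = unfold(q, 28, 4, kernel=32, stride=32, padding=0)    # (32, B * 4, C)
#   k_blocks = unfold(k, 28, 4, kernel=96, stride=32, padding=32)   # (96, B * 4, C)
# so each 32-frame query block attends to its own block plus one block of left and right
# context, giving attention scores of shape (num_heads, B * 4, 32, 96).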
- q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + # divide into blocks by unfold function + block_size = self.block_size + num_blocks = (seq_len + block_size - 1) // block_size + pad_len = num_blocks * block_size - seq_len + + # (kernel, batch_size * num_blocks, channel) + q_blocks = unfold(q, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0) + p_blocks = unfold(p, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0) + k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size) # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + time1 = q_blocks.size(0) + time2 = k_blocks.size(0) + new_batch_size = batch_size * num_blocks - attn_scores = torch.matmul(q, k) + q_blocks = q_blocks.reshape(time1, new_batch_size, num_heads, query_head_dim) + p_blocks = p_blocks.reshape(time1, new_batch_size, num_heads, pos_head_dim) + k_blocks = k_blocks.reshape(time2, new_batch_size, num_heads, query_head_dim) + + q_blocks = q_blocks.permute(2, 1, 0, 3) # (head, new_batch, time1, query_head_dim) + p_blocks = p_blocks.permute(2, 1, 0, 3) # (head, new_batch, time1, pos_head_dim) + k_blocks = k_blocks.permute(2, 1, 3, 0) # (head, new_batch, d_k, time2) + + # (head, new_batch, time1, time2) + attn_scores = torch.matmul(q_blocks, k_blocks) use_pos_scores = False if torch.jit.is_scripting() or torch.jit.is_tracing(): @@ -1528,32 +1558,21 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if use_pos_scores: pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + pos_emb = pos_emb.reshape(1, time1 + time2 - 1, num_heads, pos_head_dim).permute(2, 0, 3, 1) + # pos shape now: (head, 1, pos_dim, time1+time2-1) - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, time1+time2-1) -> (head, batch, time1, time1+time2-1) # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) + pos_scores = torch.matmul(p_blocks, pos_emb) # the following .as_strided() expression converts the last axis of pos_scores from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. 
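# Illustrative sketch of what the as_strided() below computes, written as an equivalent
# (slower) gather, mirroring the torch.jit.is_tracing() branch in the original code;
# here n == time1 + time2 - 1 is the length of the relative-position axis:
#   rows = torch.arange(time1 - 1, -1, -1).unsqueeze(-1)                    # (time1, 1)
#   cols = torch.arange(time2)                                              # (time2,)
#   indexes = (rows + cols).expand(num_heads, new_batch_size, time1, time2)
#   abs_scores = pos_scores.gather(dim=-1, index=indexes)
# i.e. target position i reads relative offsets (time1 - 1 - i) .. (time1 - 1 - i + time2 - 1).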
- if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) - else: - pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) + pos_scores = pos_scores.as_strided((num_heads, new_batch_size, time1, time2), + (pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2)-pos_scores.stride(3), + pos_scores.stride(3)), + storage_offset=pos_scores.stride(3) * (time1 - 1)) attn_scores = attn_scores + pos_scores @@ -1577,9 +1596,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): penalty=1.0e-04, name=self.name) - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + assert attn_scores.shape == (num_heads, new_batch_size, time1, time2) if attn_mask is not None: + # TODO: assert attn_mask.dtype == torch.bool # use -1000 to avoid nan's where attn_mask and key_padding_mask make # all scores zero. It's important that this be large enough that exp(-1000) @@ -1587,12 +1607,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # compares the final weights with zero. attn_scores = attn_scores.masked_fill(attn_mask, -1000) - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) + assert key_padding_mask is not None + assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape + attn_offsets = (~key_padding_mask).float() # 0 at padding positions + + # (seq_len, batch, 1) + attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1) + # (time2, new_batch_size) + attn_offsets = unfold( + attn_offsets, pad_len, num_blocks, + kernel=block_size * 3, stride=block_size, padding=block_size, + ).squeeze(-1) + + # For the blocks are all padding + all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0) # (1, new_batch_size) + all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1) # (1, new_batch_size, 1, 1) + + attn_offsets = 1 - attn_offsets # 1 at padding positions + # attn_offsets[attn_offsets != 0] = float("-inf") + attn_offsets[attn_offsets != 0] = -1000 + # attn_offsets = attn_offsets.masked_fill((attn_offsets != 0), -1000) + + # (1, new_batch_size, 1, time2) + attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0) + + attn_scores = attn_scores + attn_offsets # We use our own version of softmax, defined in scaling.py, which should # save a little of the memory used in backprop by, if we are in @@ -1600,6 +1639,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # half-precision output for backprop purposes. 
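# Note on the all-padding case handled just after the softmax below: softmax is
# shift-invariant, so a row whose scores were all offset by -1000 still yields uniform
# weights summing to 1 rather than zeros, e.g.
#   softmax(torch.full((1, 4), -1000.0), dim=-1)  ->  tensor([[0.25, 0.25, 0.25, 0.25]])
# which is why blocks consisting entirely of padding are zeroed out via all_pad_mask.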
attn_weights = softmax(attn_scores, dim=-1) + # For the blocks are all padding + attn_weights = attn_weights.masked_fill(all_pad_mask, 0.0) + if torch.jit.is_scripting() or torch.jit.is_tracing(): pass elif random.random() < 0.001 and not self.training: @@ -1678,7 +1720,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # [where seq_len2 represents relative position.] pos_scores = torch.matmul(p, pos_emb) - + if torch.jit.is_tracing(): (num_heads, batch_size, time1, n) = pos_scores.shape rows = torch.arange(start=time1 - 1, end=-1, step=-1) @@ -1743,8 +1785,10 @@ class SelfAttention(nn.Module): embed_dim: int, num_heads: int, value_head_dim: int, + block_size: int, ) -> None: super().__init__() + self.block_size = block_size self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) @@ -1766,27 +1810,45 @@ class SelfAttention(nn.Module): """ Args: x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. + attn_weights: a tensor of attention weights, of shape + (hum_heads, batch_size * num_blocks, block_size, block_size * 3) + interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len), + where num_blocks = (seq_len + block_size - 1) // block_size. + Expect attn_weights.sum(dim=-1) == 1. Returns: a tensor with the same shape as x. """ (seq_len, batch_size, embed_dim) = x.shape num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + # divide into blocks by unfold function + block_size = self.block_size + num_blocks = (seq_len + block_size - 1) // block_size + pad_len = num_blocks * block_size - seq_len + new_batch_size = batch_size * num_blocks + time1 = block_size # target length + time2 = 3 * block_size # source length + + assert attn_weights.shape == (num_heads, new_batch_size, time1, time2) x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] + + # (time2, new_batch_size, channel) + x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size) + + x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, new_batch_size, time2, value_head_dim) + value_head_dim = x_blocks.shape[-1] # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) + x = torch.matmul(attn_weights, x_blocks) + # v: (num_heads, new_batch_size, time1, value_head_dim) - x = x.permute(2, 1, 0, 3).contiguous().view( - seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(num_heads, batch_size, num_blocks, time1, value_head_dim) + x = x.permute(2, 3, 1, 0, 4).contiguous().view( + num_blocks * time1, batch_size, num_heads * value_head_dim) + + x = x[:seq_len] # (seq_len, batch_size, value_dim) # returned value is of shape (seq_len, batch_size, embed_dim), like the input. 
x = self.out_proj(x) @@ -1896,10 +1958,12 @@ class NonlinAttention(nn.Module): self, channels: int, hidden_channels: int, + block_size: int, ) -> None: super().__init__() self.hidden_channels = hidden_channels + self.block_size = block_size self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) @@ -1942,7 +2006,11 @@ class NonlinAttention(nn.Module): """. Args: x: a Tensor of shape (seq_len, batch_size, num_channels) -attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + attn_weights: a tensor of attention weights, of shape + (hum_heads, batch_size * num_blocks, block_size, block_size * 3) + interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len), + where num_blocks = (seq_len + block_size - 1) // block_size. + Expect attn_weights.sum(dim=-1) == 1. Returns: a Tensor with the same shape as x """ @@ -1965,13 +2033,31 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) (seq_len, batch_size, embed_dim) = x.shape num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + # divide into blocks by unfold function + block_size = self.block_size + num_blocks = (seq_len + block_size - 1) // block_size + pad_len = num_blocks * block_size - seq_len + new_batch_size = batch_size * num_blocks + time1 = block_size # target length + time2 = 3 * block_size # source length + + assert attn_weights.shape == (num_heads, new_batch_size, time1, time2) + + # (time2, new_batch_size, channel) + x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size) + + x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, new_batch_size, time2, head_dim) + + x = torch.matmul(attn_weights, x_blocks) + # now x: (num_heads, new_batch_size, time1, head_dim) + + x = x.reshape(num_heads, batch_size, num_blocks, time1, -1) + x = x.permute(2, 3, 1, 0, 4).contiguous().view( + num_blocks * time1, batch_size, embed_dim) + + x = x[:seq_len] # (seq_len, batch_size, embed_dim) y = self.identity2(y) x = x * y @@ -2220,30 +2306,37 @@ class ScalarMultiply(nn.Module): def _test_zipformer_main(causal: bool = False): - batch_size = 5 - seq_len = 20 # Just make sure the forward pass runs. + from icefall.utils import make_pad_mask + c = Zipformer2( encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), + downsampling_factor=(1, 2), + block_size=4, causal=causal, chunk_size=(4,) if causal else (-1,), left_context_frames=(64,) ) - batch_size = 5 - seq_len = 20 + batch_size = 2 + seq_len = 14 + # Just make sure the forward pass runs. 
- f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) + x = torch.randn(seq_len, batch_size, 64) + lengths = torch.full((batch_size,), seq_len, dtype=torch.int64) + lengths[-1] = 1 + src_key_padding_mask = make_pad_mask(lengths) + f = c(x, lengths, src_key_padding_mask) f[0].sum().backward() c.eval() - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) + + x = torch.randn(seq_len, batch_size, 64) + lengths = torch.full((batch_size,), seq_len, dtype=torch.int64) + lengths[-1] = seq_len - 2 + src_key_padding_mask = make_pad_mask(lengths) + f = c(x, lengths, src_key_padding_mask) f # to remove flake8 warnings + print(f[0].sum()) if __name__ == "__main__": @@ -2251,4 +2344,4 @@ if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) _test_zipformer_main(False) - _test_zipformer_main(True) + # _test_zipformer_main(True) From 6aaa971b34251da78a8c71d7c0e8324894ed5ef2 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 21 Jul 2023 11:34:19 +0800 Subject: [PATCH 2/5] make block-size be a list --- egs/librispeech/ASR/zipformer/train.py | 6 +++--- egs/librispeech/ASR/zipformer/zipformer.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 5a8dae619..eabed65fb 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -189,9 +189,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--block-size", - type=int, + type=str, default="32", - help="Block size used in block-wise attention", + help="Block size used in block-wise attention; a single int or comma-separated list", ) parser.add_argument( @@ -581,7 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_heads=_to_int_tuple(params.num_heads), feedforward_dim=_to_int_tuple(params.feedforward_dim), cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - block_size=params.block_size, + block_size=_to_int_tuple(params.block_size), dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), warmup_batches=4000.0, causal=params.causal, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d12b9f22b..0ca0fcaa4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -106,7 +106,7 @@ class Zipformer2(EncoderInterface): feedforward_dim: Union[int, Tuple[int]] = 1536, cnn_module_kernel: Union[int, Tuple[int]] = 31, pos_dim: int = 192, - block_size: int = 32, + block_size: Union[int, Tuple[int]] = 32, dropout: FloatLike = None, # see code below for default warmup_batches: float = 4000.0, causal: bool = False, @@ -142,7 +142,7 @@ class Zipformer2(EncoderInterface): self.num_heads = num_heads = _to_tuple(num_heads) feedforward_dim = _to_tuple(feedforward_dim) self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - self.block_size = block_size + self.block_size = block_size = _to_tuple(block_size) self.causal = causal self.chunk_size = chunk_size @@ -168,7 +168,7 @@ class Zipformer2(EncoderInterface): feedforward_dim=feedforward_dim[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], - block_size=block_size // ds, + block_size=block_size[i], causal=causal, ) @@ -178,7 +178,7 @@ class Zipformer2(EncoderInterface): encoder_layer, num_encoder_layers[i], pos_dim=pos_dim, - block_size=block_size // ds, 
+ block_size=block_size[i], dropout=dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), From ee485c02fcff3898d28f36108136c356fbaef313 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 21 Jul 2023 15:38:22 +0800 Subject: [PATCH 3/5] modify attn_offsets --- egs/librispeech/ASR/zipformer/zipformer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 0ca0fcaa4..032262e76 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1598,6 +1598,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): assert attn_scores.shape == (num_heads, new_batch_size, time1, time2) + assert attn_mask is None if attn_mask is not None: # TODO: assert attn_mask.dtype == torch.bool @@ -1607,9 +1608,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # compares the final weights with zero. attn_scores = attn_scores.masked_fill(attn_mask, -1000) - assert key_padding_mask is not None - assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape - attn_offsets = (~key_padding_mask).float() # 0 at padding positions + # Used to mask out the padding positions + attn_offsets = torch.ones(batch_size, seq_len, device=x.device) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape + attn_offsets = attn_offsets.masked_fill(key_padding_mask, 0.0) # 0 at padding positions # (seq_len, batch, 1) attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1) @@ -1619,14 +1623,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module): kernel=block_size * 3, stride=block_size, padding=block_size, ).squeeze(-1) - # For the blocks are all padding + # Used for the blocks are all padding all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0) # (1, new_batch_size) all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1) # (1, new_batch_size, 1, 1) attn_offsets = 1 - attn_offsets # 1 at padding positions - # attn_offsets[attn_offsets != 0] = float("-inf") attn_offsets[attn_offsets != 0] = -1000 - # attn_offsets = attn_offsets.masked_fill((attn_offsets != 0), -1000) # (1, new_batch_size, 1, time2) attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0) From 215541c7c509ba2c08fb3cb24c85acf65549ecb2 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 23 Jul 2023 16:12:57 +0800 Subject: [PATCH 4/5] Do block-wise attention when seq_len is larger than 512, with block_size <= 512 --- egs/librispeech/ASR/zipformer/train.py | 8 +- egs/librispeech/ASR/zipformer/zipformer.py | 357 ++++++++++++++++++--- 2 files changed, 318 insertions(+), 47 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index eabed65fb..9b2dd8a95 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -188,10 +188,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( - "--block-size", + "--max-block-size", type=str, - default="32", - help="Block size used in block-wise attention; a single int or comma-separated list", + default="512", + help="Max block size used in block-wise attention; a single int or comma-separated list", ) parser.add_argument( @@ -581,7 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_heads=_to_int_tuple(params.num_heads), 
feedforward_dim=_to_int_tuple(params.feedforward_dim), cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - block_size=_to_int_tuple(params.block_size), + max_block_size=_to_int_tuple(params.max_block_size), dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), warmup_batches=4000.0, causal=params.causal, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 032262e76..7985a11fb 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -106,7 +106,8 @@ class Zipformer2(EncoderInterface): feedforward_dim: Union[int, Tuple[int]] = 1536, cnn_module_kernel: Union[int, Tuple[int]] = 31, pos_dim: int = 192, - block_size: Union[int, Tuple[int]] = 32, + max_block_size: Union[int, Tuple[int]] = 512, + block_pad: int = 16, dropout: FloatLike = None, # see code below for default warmup_batches: float = 4000.0, causal: bool = False, @@ -142,7 +143,7 @@ class Zipformer2(EncoderInterface): self.num_heads = num_heads = _to_tuple(num_heads) feedforward_dim = _to_tuple(feedforward_dim) self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - self.block_size = block_size = _to_tuple(block_size) + self.max_block_size = max_block_size = _to_tuple(max_block_size) self.causal = causal self.chunk_size = chunk_size @@ -168,7 +169,6 @@ class Zipformer2(EncoderInterface): feedforward_dim=feedforward_dim[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], - block_size=block_size[i], causal=causal, ) @@ -178,7 +178,8 @@ class Zipformer2(EncoderInterface): encoder_layer, num_encoder_layers[i], pos_dim=pos_dim, - block_size=block_size[i], + max_block_size=max_block_size[i], + block_pad=block_pad, dropout=dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), @@ -542,7 +543,6 @@ class Zipformer2EncoderLayer(nn.Module): feedforward_dim: int, dropout: FloatLike = 0.1, cnn_module_kernel: int = 31, - block_size: int = 32, causal: bool = False, attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), @@ -576,14 +576,14 @@ class Zipformer2EncoderLayer(nn.Module): self.self_attn_weights = RelPositionMultiheadAttentionWeights( embed_dim, pos_dim=pos_dim, num_heads=num_heads, query_head_dim=query_head_dim, pos_head_dim=pos_head_dim, - block_size=block_size, dropout=0.0, + dropout=0.0, ) self.self_attn1 = SelfAttention(embed_dim, num_heads, - value_head_dim, block_size=block_size) + value_head_dim) self.self_attn2 = SelfAttention(embed_dim, num_heads, - value_head_dim, block_size=block_size) + value_head_dim) self.feed_forward1 = FeedforwardModule(embed_dim, (feedforward_dim * 3) // 4, @@ -598,8 +598,7 @@ class Zipformer2EncoderLayer(nn.Module): dropout) self.nonlin_attention = NonlinAttention(embed_dim, - hidden_channels=3 * embed_dim // 4, - block_size=block_size) + hidden_channels=3 * embed_dim // 4) self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel, @@ -680,6 +679,8 @@ class Zipformer2EncoderLayer(nn.Module): self, src: Tensor, pos_emb: Tensor, + block_size: int = 0, + block_pad: int = 16, chunk_size: int = -1, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, @@ -689,6 +690,8 @@ class Zipformer2EncoderLayer(nn.Module): Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). 
pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + block_size: size of block + block_pad: pad size at each side of block chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. feature_mask: something that broadcasts with src, that we'll multiply `src` by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) @@ -710,12 +713,21 @@ class Zipformer2EncoderLayer(nn.Module): attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) + if block_size == 0: + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + else: + attn_weights = self.self_attn_weights.forward_block( + src, + pos_emb=pos_emb, + block_size=block_size, + block_pad=block_pad, + key_padding_mask=src_key_padding_mask, + ) src = src + self.feed_forward1(src) @@ -733,11 +745,20 @@ class Zipformer2EncoderLayer(nn.Module): selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + if block_size == 0: + na = self.nonlin_attention(src, selected_attn_weights) + else: + na = self.nonlin_attention.forward_block( + src, selected_attn_weights, block_size=block_size, block_pad=block_pad) + na = self.balancer_na(na) src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask) - self_attn = self.self_attn1(src, attn_weights) + if block_size == 0: + self_attn = self.self_attn1(src, attn_weights) + else: + self_attn = self.self_attn1.forward_block( + src, attn_weights, block_size=block_size, block_pad=block_pad) src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) @@ -759,7 +780,11 @@ class Zipformer2EncoderLayer(nn.Module): # bypass in the middle of the layer. src = self.bypass_mid(src_orig, src) - self_attn = self.self_attn2(src, attn_weights) + if block_size == 0: + self_attn = self.self_attn2(src, attn_weights) + else: + self_attn = self.self_attn2.forward_block( + src, attn_weights, block_size=block_size, block_pad=block_pad) src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) @@ -925,10 +950,11 @@ class Zipformer2Encoder(nn.Module): encoder_layer: nn.Module, num_layers: int, pos_dim: int, - block_size: int, + max_block_size: int, dropout: float, warmup_begin: float, warmup_end: float, + block_pad: int = 16, initial_layerdrop_rate: float = 0.5, final_layerdrop_rate: float = 0.05, ) -> None: @@ -940,7 +966,8 @@ class Zipformer2Encoder(nn.Module): [copy.deepcopy(encoder_layer) for i in range(num_layers)] ) self.num_layers = num_layers - self.block_size = block_size + self.max_block_size = max_block_size + self.block_pad = block_pad assert 0 <= warmup_begin <= warmup_end @@ -976,7 +1003,20 @@ class Zipformer2Encoder(nn.Module): Returns: a Tensor with the same shape as src. 
""" - pos_emb = self.encoder_pos(src, block_size=self.block_size) + seq_len = src.size(0) + max_block_size = self.max_block_size + block_pad = self.block_pad + if seq_len > max_block_size: + num_blocks = math.ceil(seq_len / max_block_size) + block_size = math.ceil(seq_len / num_blocks) + pos_emb = self.encoder_pos(src, rel_pos=block_size + block_pad) + # if __name__ == "__main__": + if random.random() < 0.2: + logging.info(f"seq_len={seq_len}, block_size={block_size}") + else: + pos_emb = self.encoder_pos(src) + block_size = 0 + output = src if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -986,6 +1026,8 @@ class Zipformer2Encoder(nn.Module): output = mod( output, pos_emb, + block_size=block_size, + block_pad=block_pad, chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, @@ -1371,7 +1413,7 @@ class CompactRelPositionalEncoding(torch.nn.Module): self.pe = pe.to(dtype=x.dtype) - def forward(self, x: Tensor, block_size: int = 0) -> Tensor: + def forward(self, x: Tensor, rel_pos: int = 0) -> Tensor: """Create positional encoding. Args: @@ -1382,9 +1424,8 @@ class CompactRelPositionalEncoding(torch.nn.Module): positional embedding, of shape (1, 2*time-1, `*`) or (1, 4*block_size-1, `*`). """ self.extend_pe(x) - rel_pos = 2 * block_size if block_size != 0 else x.size(0) - # length of positive side: 2 * block_size - # length of negative side: 2 * block_size + if rel_pos == 0: + rel_pos = x.size(0) pos_emb = self.pe[ self.pe.size(0) // 2 - rel_pos @@ -1423,7 +1464,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module): num_heads: int, query_head_dim: int, pos_head_dim: int, - block_size: int, dropout: float = 0.0, pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)) @@ -1433,7 +1473,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module): self.num_heads = num_heads self.query_head_dim = query_head_dim self.pos_head_dim = pos_head_dim - self.block_size = block_size self.dropout = dropout self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) self.name = None # will be overwritten in training code; for diagnostics. @@ -1486,11 +1525,158 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_emb: Tensor, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[...,0:query_dim] + k = x[...,query_dim:2*query_dim] + # p is the position-encoding query + p = x[...,2*query_dim:] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. 
+ p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), + (pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2)-pos_scores.stride(3), + pos_scores.stride(3)), + storage_offset=pos_scores.stride(3) * (seq_len - 1)) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt(attn_scores, + limit=25.0, + penalty=1.0e-04, + name=self.name) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. 
It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. + attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def forward_block( + self, + x: Tensor, + pos_emb: Tensor, + block_size: int, + block_pad: int, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, ) -> Tensor: r""" Args: x: input of shape (seq_len, batch_size, embed_dim) pos_emb: Positional embedding tensor, of shape (1, 4*block_size-1, pos_dim) + block_size: size of block + block_pad: pad size at each side of block key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that are True in this mask will be ignored as sources in the attention weighting. attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), @@ -1524,14 +1710,13 @@ class RelPositionMultiheadAttentionWeights(nn.Module): p = self.copy_pos_query(p) # for diagnostics only, does nothing. # divide into blocks by unfold function - block_size = self.block_size num_blocks = (seq_len + block_size - 1) // block_size pad_len = num_blocks * block_size - seq_len # (kernel, batch_size * num_blocks, channel) q_blocks = unfold(q, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0) p_blocks = unfold(p, pad_len, num_blocks, kernel=block_size, stride=block_size, padding=0) - k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size) + k_blocks = unfold(k, pad_len, num_blocks, kernel=block_size + 2 * block_pad, stride=block_size, padding=block_pad) # time1 refers to target, time2 refers to source. time1 = q_blocks.size(0) @@ -1620,7 +1805,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # (time2, new_batch_size) attn_offsets = unfold( attn_offsets, pad_len, num_blocks, - kernel=block_size * 3, stride=block_size, padding=block_size, + kernel=block_size + 2 * block_pad, stride=block_size, padding=block_pad, ).squeeze(-1) # Used for the blocks are all padding @@ -1787,10 +1972,8 @@ class SelfAttention(nn.Module): embed_dim: int, num_heads: int, value_head_dim: int, - block_size: int, ) -> None: super().__init__() - self.block_size = block_size self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) @@ -1808,6 +1991,44 @@ class SelfAttention(nn.Module): self, x: Tensor, attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. 
+ """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = x.permute(2, 1, 0, 3).contiguous().view( + seq_len, batch_size, num_heads * value_head_dim) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def forward_block( + self, + x: Tensor, + attn_weights: Tensor, + block_size: int, + block_pad: int, ) -> Tensor: """ Args: @@ -1817,6 +2038,8 @@ class SelfAttention(nn.Module): interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len), where num_blocks = (seq_len + block_size - 1) // block_size. Expect attn_weights.sum(dim=-1) == 1. + block_size: size of block + block_pad: pad size at each side of block Returns: a tensor with the same shape as x. """ @@ -1824,19 +2047,18 @@ class SelfAttention(nn.Module): num_heads = attn_weights.shape[0] # divide into blocks by unfold function - block_size = self.block_size num_blocks = (seq_len + block_size - 1) // block_size pad_len = num_blocks * block_size - seq_len new_batch_size = batch_size * num_blocks time1 = block_size # target length - time2 = 3 * block_size # source length + time2 = block_size + 2 * block_pad # source length assert attn_weights.shape == (num_heads, new_batch_size, time1, time2) x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) # (time2, new_batch_size, channel) - x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size) + x_blocks = unfold(x, pad_len, num_blocks, kernel=time2, stride=block_size, padding=block_pad) x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3) # now x: (num_heads, new_batch_size, time2, value_head_dim) @@ -1960,12 +2182,10 @@ class NonlinAttention(nn.Module): self, channels: int, hidden_channels: int, - block_size: int, ) -> None: super().__init__() self.hidden_channels = hidden_channels - self.block_size = block_size self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) @@ -2004,6 +2224,55 @@ class NonlinAttention(nn.Module): self, x: Tensor, attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) +attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. 
+ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def forward_block( + self, + x: Tensor, + attn_weights: Tensor, + block_size: int, + block_pad: int, ) -> Tensor: """. Args: @@ -2013,6 +2282,8 @@ class NonlinAttention(nn.Module): interpreted as (hum_heads, batch_size * num_blocks, tgt_seq_len, src_seq_len), where num_blocks = (seq_len + block_size - 1) // block_size. Expect attn_weights.sum(dim=-1) == 1. + block_size: size of block + block_pad: pad size at each side of block Returns: a Tensor with the same shape as x """ @@ -2037,17 +2308,16 @@ class NonlinAttention(nn.Module): num_heads = attn_weights.shape[0] # divide into blocks by unfold function - block_size = self.block_size num_blocks = (seq_len + block_size - 1) // block_size pad_len = num_blocks * block_size - seq_len new_batch_size = batch_size * num_blocks time1 = block_size # target length - time2 = 3 * block_size # source length + time2 = block_size + 2 * block_pad # source length assert attn_weights.shape == (num_heads, new_batch_size, time1, time2) # (time2, new_batch_size, channel) - x_blocks = unfold(x, pad_len, num_blocks, kernel=block_size * 3, stride=block_size, padding=block_size) + x_blocks = unfold(x, pad_len, num_blocks, kernel=time2, stride=block_size, padding=block_pad) x_blocks = x_blocks.reshape(time2, new_batch_size, num_heads, -1).permute(2, 1, 0, 3) # now x: (num_heads, new_batch_size, time2, head_dim) @@ -2315,13 +2585,14 @@ def _test_zipformer_main(causal: bool = False): c = Zipformer2( encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), downsampling_factor=(1, 2), - block_size=4, + max_block_size=14, + block_pad=1, causal=causal, chunk_size=(4,) if causal else (-1,), left_context_frames=(64,) ) batch_size = 2 - seq_len = 14 + seq_len = 29 # Just make sure the forward pass runs. 
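# Worked block sizing for this test (first encoder stack, which sees seq_len = 29 with
# max_block_size = 14 and block_pad = 1), following Zipformer2Encoder.forward and the
# unfold() calls in forward_block above:
#   num_blocks = ceil(29 / 14) = 3
#   block_size = ceil(29 / 3)  = 10
#   pad_len    = 3 * 10 - 29   = 1
#   kernel     = block_size + 2 * block_pad = 12 source frames per 10-frame target block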
x = torch.randn(seq_len, batch_size, 64) From 3dc33515c0bcba749a775cf08b8aba546763fb66 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 4 Aug 2023 10:26:52 +0800 Subject: [PATCH 5/5] split utterance over 512 frames into overlapping chunks --- egs/librispeech/ASR/zipformer/scaling.py | 30 +++- egs/librispeech/ASR/zipformer/zipformer.py | 152 ++++++++++++++------- 2 files changed, 131 insertions(+), 51 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 6c315b8e8..469ad3c6c 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1612,7 +1612,7 @@ def unfold( blocks: (kernel, batch_size * num_blocks, channel) """ seq_len, batch_size, channel = x.size() - x = x.permute(1, 2, 0) # (B, D, T) + x = x.permute(1, 2, 0) # (batch_size, channel, seq_len) x = nn.functional.pad(x, pad=(0, x_pad), value=0.0) @@ -1629,6 +1629,34 @@ def unfold( return blocks +def fold( + blocks: Tensor, seq_len: int, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int +) -> Tensor: + """ + Args: + blocks: (kernel, batch_size * num_blocks, channel) + Returns: + x: (seq_len, batch_size, channel) + """ + batch_size = blocks.size(1) // num_blocks + channel = blocks.size(2) + + blocks = blocks.reshape(kernel, batch_size, num_blocks, channel) + blocks = blocks.permute(1, 3, 0, 2).reshape(batch_size, channel * kernel, num_blocks) + + x = nn.functional.fold( + blocks, + output_size=(seq_len + x_pad, 1), + kernel_size=(kernel, 1), + padding=(padding, 0), + stride=(stride, 1), + ) + x = x.squeeze(-1).permute(2, 0, 1) + x = x[:seq_len] # (seq_len, batch_size, channel) + + return x + + def _test_whiten(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"_test_whiten(): proportion = {proportion}") diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7985a11fb..f56dbe6e9 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -39,6 +39,7 @@ from scaling import ( FloatLike, limit_param_value, convert_num_channels, + fold, unfold, ) from torch import Tensor, nn @@ -679,10 +680,10 @@ class Zipformer2EncoderLayer(nn.Module): self, src: Tensor, pos_emb: Tensor, - block_size: int = 0, - block_pad: int = 16, chunk_size: int = -1, attn_mask: Optional[Tensor] = None, + attn_offsets: Optional[Tensor] = None, + all_pad_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """ @@ -713,21 +714,13 @@ class Zipformer2EncoderLayer(nn.Module): attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 # attn_weights: (num_heads, batch_size, seq_len, seq_len) - if block_size == 0: - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - else: - attn_weights = self.self_attn_weights.forward_block( - src, - pos_emb=pos_emb, - block_size=block_size, - block_pad=block_pad, - key_padding_mask=src_key_padding_mask, - ) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + attn_offsets=attn_offsets, + all_pad_mask=all_pad_mask, + ) src = src + self.feed_forward1(src) @@ -745,20 +738,12 @@ class Zipformer2EncoderLayer(nn.Module): selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) - if block_size == 0: - na = 
self.nonlin_attention(src, selected_attn_weights) - else: - na = self.nonlin_attention.forward_block( - src, selected_attn_weights, block_size=block_size, block_pad=block_pad) + na = self.nonlin_attention(src, selected_attn_weights) na = self.balancer_na(na) src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask) - if block_size == 0: - self_attn = self.self_attn1(src, attn_weights) - else: - self_attn = self.self_attn1.forward_block( - src, attn_weights, block_size=block_size, block_pad=block_pad) + self_attn = self.self_attn1(src, attn_weights) src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) @@ -780,11 +765,7 @@ class Zipformer2EncoderLayer(nn.Module): # bypass in the middle of the layer. src = self.bypass_mid(src_orig, src) - if block_size == 0: - self_attn = self.self_attn2(src, attn_weights) - else: - self_attn = self.self_attn2.forward_block( - src, attn_weights, block_size=block_size, block_pad=block_pad) + self_attn = self.self_attn2(src, attn_weights) src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask) @@ -994,7 +975,7 @@ class Zipformer2Encoder(nn.Module): src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + by at every layer: if a Tensor, likely of shape (1, batch_size, embedding_dim) attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). True means masked position. May be None. @@ -1003,20 +984,71 @@ class Zipformer2Encoder(nn.Module): Returns: a Tensor with the same shape as src. 
""" - seq_len = src.size(0) + seq_len, batch_size, channel = src.size() max_block_size = self.max_block_size block_pad = self.block_pad + if seq_len > max_block_size: + # divide into blocks with overlaps num_blocks = math.ceil(seq_len / max_block_size) block_size = math.ceil(seq_len / num_blocks) - pos_emb = self.encoder_pos(src, rel_pos=block_size + block_pad) - # if __name__ == "__main__": - if random.random() < 0.2: - logging.info(f"seq_len={seq_len}, block_size={block_size}") + pad_len = num_blocks * block_size - seq_len + kernel_size = block_size + 2 * block_pad + if random.random() < 0.2 or __name__ == "__main__": + logging.info(f"seq_len={seq_len}, block_size={block_size}, pad_len={pad_len}") + + # (block_size + 2 * block_pad, batch_size * num_blocks, channel) + src = unfold( + src, pad_len, num_blocks, + kernel=kernel_size, stride=block_size, padding=block_pad + ) + + # Used to mask out the padding positions + attn_offsets = torch.ones(batch_size, seq_len, device=src.device) + + if src_key_padding_mask is not None: + assert src_key_padding_mask.shape == (batch_size, seq_len), src_key_padding_mask.shape + attn_offsets = attn_offsets.masked_fill(src_key_padding_mask, 0.0) # 0 at padding positions + # (seq_len, batch, 1) + attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1) + # (kernel_size, new_batch_size) + attn_offsets = unfold( + attn_offsets, pad_len, num_blocks, + kernel=kernel_size, stride=block_size, padding=block_pad, + ).squeeze(-1) + + # Used for the blocks are all padding + all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0) # (1, new_batch_size) + all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1) # (1, new_batch_size, 1, 1) + + # (new_batch_size, kernel_size) + src_key_padding_mask = (attn_offsets == 0).transpose(0, 1) + + attn_offsets = 1 - attn_offsets # 1 at padding positions + attn_offsets[attn_offsets != 0] = -1000 + + # (1, new_batch_size, 1, kernel) + attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0) + + # feature_mask: (1, batch_size, channel) + if isinstance(feature_mask, Tensor): + feature_mask = feature_mask.unsqueeze(2).expand(-1, -1, num_blocks, -1) + # now (kernel_size, batch_size, num_blocks, channel) + feature_mask = feature_mask.reshape(1, batch_size * num_blocks, channel) else: - pos_emb = self.encoder_pos(src) block_size = 0 + # Used to mask out the padding positions + attn_offsets = torch.zeros(batch_size, seq_len, device=src.device) + if src_key_padding_mask is not None: + assert src_key_padding_mask.shape == (batch_size, seq_len), src_key_padding_mask.shape + attn_offsets = attn_offsets.masked_fill(src_key_padding_mask, -1000) # 0 at padding positions + # (1, batch_size, 1, seq_len) + attn_offsets = attn_offsets.unsqueeze(1).unsqueeze(0) + + all_pad_mask = None + + pos_emb = self.encoder_pos(src) output = src if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -1026,16 +1058,31 @@ class Zipformer2Encoder(nn.Module): output = mod( output, pos_emb, - block_size=block_size, - block_pad=block_pad, chunk_size=chunk_size, attn_mask=attn_mask, + attn_offsets=attn_offsets, + all_pad_mask=all_pad_mask, src_key_padding_mask=src_key_padding_mask, ) if not torch.jit.is_scripting() and not torch.jit.is_tracing(): output = output * feature_mask + if seq_len > max_block_size: + # overlap-and-add + output = fold( + output, seq_len, pad_len, num_blocks, + kernel=kernel_size, stride=block_size, padding=block_pad + ) # (seq_len, batch_size, channel) + mask = torch.ones( + kernel_size, batch_size * num_blocks, 1, 
device=src.device, + ) + mask = fold( + mask, seq_len, pad_len, num_blocks, + kernel=kernel_size, stride=block_size, padding=block_pad + ) # (seq_len, batch_size, 1) + output = output / mask + return output def streaming_forward( @@ -1523,7 +1570,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module): self, x: Tensor, pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, + attn_offsets: Optional[Tensor] = None, + all_pad_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, ) -> Tensor: r""" @@ -1631,6 +1679,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) if attn_mask is not None: + assert attn_mask is None assert attn_mask.dtype == torch.bool # use -1000 to avoid nan's where attn_mask and key_padding_mask make # all scores zero. It's important that this be large enough that exp(-1000) @@ -1638,12 +1687,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # compares the final weights with zero. attn_scores = attn_scores.masked_fill(attn_mask, -1000) - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) + if attn_offsets is not None: + # attn_offsets: (1, batch_size, 1, seq_len) + # or (1, new_batch_size, 1, kernel) + attn_scores = attn_scores + attn_offsets # We use our own version of softmax, defined in scaling.py, which should # save a little of the memory used in backprop by, if we are in @@ -1651,6 +1698,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # half-precision output for backprop purposes. attn_weights = softmax(attn_scores, dim=-1) + if all_pad_mask is not None: + # For the blocks are all padding + # all_pad_mask: (1, new_batch_size, 1, 1) + attn_weights = attn_weights.masked_fill(all_pad_mask, 0.0) + if torch.jit.is_scripting() or torch.jit.is_tracing(): pass elif random.random() < 0.001 and not self.training: @@ -2586,13 +2638,13 @@ def _test_zipformer_main(causal: bool = False): encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), downsampling_factor=(1, 2), max_block_size=14, - block_pad=1, + block_pad=2, causal=causal, chunk_size=(4,) if causal else (-1,), left_context_frames=(64,) ) batch_size = 2 - seq_len = 29 + seq_len = 27 # Just make sure the forward pass runs. x = torch.randn(seq_len, batch_size, 64)
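# A minimal round-trip sketch (toy sizes matching the updated test above) of the
# overlap-and-add that Zipformer2Encoder.forward performs in this last patch, assuming the
# unfold()/fold() helpers added to scaling.py: folding unmodified blocks back and dividing
# by the coverage count recovers the input exactly.

import torch
from scaling import fold, unfold

seq_len, batch_size, channel = 27, 2, 8
block_size, block_pad = 14, 2                           # ceil(27 / ceil(27 / 14)) == 14
num_blocks = (seq_len + block_size - 1) // block_size   # 2
pad_len = num_blocks * block_size - seq_len             # 1
kernel = block_size + 2 * block_pad                     # 18

x = torch.randn(seq_len, batch_size, channel)
blocks = unfold(x, pad_len, num_blocks, kernel=kernel, stride=block_size, padding=block_pad)
ones = torch.ones(kernel, batch_size * num_blocks, 1)

y = fold(blocks, seq_len, pad_len, num_blocks, kernel=kernel, stride=block_size, padding=block_pad)
count = fold(ones, seq_len, pad_len, num_blocks, kernel=kernel, stride=block_size, padding=block_pad)
torch.testing.assert_close(y / count, x)                # overlapping frames average back exactly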