From 3dc33515c0bcba749a775cf08b8aba546763fb66 Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Fri, 4 Aug 2023 10:26:52 +0800
Subject: [PATCH] split utterance over 512 frames into overlapping chunks

---
 egs/librispeech/ASR/zipformer/scaling.py   |  30 +++-
 egs/librispeech/ASR/zipformer/zipformer.py | 152 ++++++++++++++-------
 2 files changed, 131 insertions(+), 51 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py
index 6c315b8e8..469ad3c6c 100644
--- a/egs/librispeech/ASR/zipformer/scaling.py
+++ b/egs/librispeech/ASR/zipformer/scaling.py
@@ -1612,7 +1612,7 @@ def unfold(
       blocks: (kernel, batch_size * num_blocks, channel)
     """
     seq_len, batch_size, channel = x.size()
-    x = x.permute(1, 2, 0)  # (B, D, T)
+    x = x.permute(1, 2, 0)  # (batch_size, channel, seq_len)
 
     x = nn.functional.pad(x, pad=(0, x_pad), value=0.0)
 
@@ -1629,6 +1629,34 @@ def unfold(
     return blocks
 
 
+def fold(
+    blocks: Tensor, seq_len: int, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int
+) -> Tensor:
+    """
+    Args:
+      blocks: (kernel, batch_size * num_blocks, channel)
+    Returns:
+      x: (seq_len, batch_size, channel)
+    """
+    batch_size = blocks.size(1) // num_blocks
+    channel = blocks.size(2)
+
+    blocks = blocks.reshape(kernel, batch_size, num_blocks, channel)
+    blocks = blocks.permute(1, 3, 0, 2).reshape(batch_size, channel * kernel, num_blocks)
+
+    x = nn.functional.fold(
+        blocks,
+        output_size=(seq_len + x_pad, 1),
+        kernel_size=(kernel, 1),
+        padding=(padding, 0),
+        stride=(stride, 1),
+    )
+    x = x.squeeze(-1).permute(2, 0, 1)
+    x = x[:seq_len]  # (seq_len, batch_size, channel)
+
+    return x
+
+
 def _test_whiten():
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"_test_whiten(): proportion = {proportion}")
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 7985a11fb..f56dbe6e9 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -39,6 +39,7 @@ from scaling import (
     FloatLike,
     limit_param_value,
     convert_num_channels,
+    fold,
     unfold,
 )
 from torch import Tensor, nn
@@ -679,10 +680,10 @@ class Zipformer2EncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        block_size: int = 0,
-        block_pad: int = 16,
         chunk_size: int = -1,
         attn_mask: Optional[Tensor] = None,
+        attn_offsets: Optional[Tensor] = None,
+        all_pad_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """
@@ -713,21 +714,13 @@ class Zipformer2EncoderLayer(nn.Module):
         attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
 
         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        if block_size == 0:
-            attn_weights = self.self_attn_weights(
-                src,
-                pos_emb=pos_emb,
-                attn_mask=attn_mask,
-                key_padding_mask=src_key_padding_mask,
-            )
-        else:
-            attn_weights = self.self_attn_weights.forward_block(
-                src,
-                pos_emb=pos_emb,
-                block_size=block_size,
-                block_pad=block_pad,
-                key_padding_mask=src_key_padding_mask,
-            )
+        attn_weights = self.self_attn_weights(
+            src,
+            pos_emb=pos_emb,
+            attn_mask=attn_mask,
+            attn_offsets=attn_offsets,
+            all_pad_mask=all_pad_mask,
+        )
 
         src = src + self.feed_forward1(src)
 
@@ -745,20 +738,12 @@ class Zipformer2EncoderLayer(nn.Module):
             selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
             selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
 
-        if block_size == 0:
-            na = self.nonlin_attention(src, selected_attn_weights)
-        else:
-            na = self.nonlin_attention.forward_block(
-                src, selected_attn_weights, block_size=block_size, block_pad=block_pad)
+        na = self.nonlin_attention(src, selected_attn_weights)
 
         na = self.balancer_na(na)
         src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
 
-        if block_size == 0:
-            self_attn = self.self_attn1(src, attn_weights)
-        else:
-            self_attn = self.self_attn1.forward_block(
-                src, attn_weights, block_size=block_size, block_pad=block_pad)
+        self_attn = self.self_attn1(src, attn_weights)
 
         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
@@ -780,11 +765,7 @@ class Zipformer2EncoderLayer(nn.Module):
         # bypass in the middle of the layer.
         src = self.bypass_mid(src_orig, src)
 
-        if block_size == 0:
-            self_attn = self.self_attn2(src, attn_weights)
-        else:
-            self_attn = self.self_attn2.forward_block(
-                src, attn_weights, block_size=block_size, block_pad=block_pad)
+        self_attn = self.self_attn2(src, attn_weights)
 
         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
@@ -994,7 +975,7 @@ class Zipformer2Encoder(nn.Module):
           src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
           chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
          feature_mask: something that broadcasts with src, that we'll multiply `src`
-               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+               by at every layer: if a Tensor, likely of shape (1, batch_size, embedding_dim)
          attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                True means masked position.  May be None.
@@ -1003,20 +984,71 @@ class Zipformer2Encoder(nn.Module):
         Returns: a Tensor with the same shape as src.
""" - seq_len = src.size(0) + seq_len, batch_size, channel = src.size() max_block_size = self.max_block_size block_pad = self.block_pad + if seq_len > max_block_size: + # divide into blocks with overlaps num_blocks = math.ceil(seq_len / max_block_size) block_size = math.ceil(seq_len / num_blocks) - pos_emb = self.encoder_pos(src, rel_pos=block_size + block_pad) - # if __name__ == "__main__": - if random.random() < 0.2: - logging.info(f"seq_len={seq_len}, block_size={block_size}") + pad_len = num_blocks * block_size - seq_len + kernel_size = block_size + 2 * block_pad + if random.random() < 0.2 or __name__ == "__main__": + logging.info(f"seq_len={seq_len}, block_size={block_size}, pad_len={pad_len}") + + # (block_size + 2 * block_pad, batch_size * num_blocks, channel) + src = unfold( + src, pad_len, num_blocks, + kernel=kernel_size, stride=block_size, padding=block_pad + ) + + # Used to mask out the padding positions + attn_offsets = torch.ones(batch_size, seq_len, device=src.device) + + if src_key_padding_mask is not None: + assert src_key_padding_mask.shape == (batch_size, seq_len), src_key_padding_mask.shape + attn_offsets = attn_offsets.masked_fill(src_key_padding_mask, 0.0) # 0 at padding positions + # (seq_len, batch, 1) + attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1) + # (kernel_size, new_batch_size) + attn_offsets = unfold( + attn_offsets, pad_len, num_blocks, + kernel=kernel_size, stride=block_size, padding=block_pad, + ).squeeze(-1) + + # Used for the blocks are all padding + all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0) # (1, new_batch_size) + all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1) # (1, new_batch_size, 1, 1) + + # (new_batch_size, kernel_size) + src_key_padding_mask = (attn_offsets == 0).transpose(0, 1) + + attn_offsets = 1 - attn_offsets # 1 at padding positions + attn_offsets[attn_offsets != 0] = -1000 + + # (1, new_batch_size, 1, kernel) + attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0) + + # feature_mask: (1, batch_size, channel) + if isinstance(feature_mask, Tensor): + feature_mask = feature_mask.unsqueeze(2).expand(-1, -1, num_blocks, -1) + # now (kernel_size, batch_size, num_blocks, channel) + feature_mask = feature_mask.reshape(1, batch_size * num_blocks, channel) else: - pos_emb = self.encoder_pos(src) block_size = 0 + # Used to mask out the padding positions + attn_offsets = torch.zeros(batch_size, seq_len, device=src.device) + if src_key_padding_mask is not None: + assert src_key_padding_mask.shape == (batch_size, seq_len), src_key_padding_mask.shape + attn_offsets = attn_offsets.masked_fill(src_key_padding_mask, -1000) # 0 at padding positions + # (1, batch_size, 1, seq_len) + attn_offsets = attn_offsets.unsqueeze(1).unsqueeze(0) + + all_pad_mask = None + + pos_emb = self.encoder_pos(src) output = src if not torch.jit.is_scripting() and not torch.jit.is_tracing(): @@ -1026,16 +1058,31 @@ class Zipformer2Encoder(nn.Module): output = mod( output, pos_emb, - block_size=block_size, - block_pad=block_pad, chunk_size=chunk_size, attn_mask=attn_mask, + attn_offsets=attn_offsets, + all_pad_mask=all_pad_mask, src_key_padding_mask=src_key_padding_mask, ) if not torch.jit.is_scripting() and not torch.jit.is_tracing(): output = output * feature_mask + if seq_len > max_block_size: + # overlap-and-add + output = fold( + output, seq_len, pad_len, num_blocks, + kernel=kernel_size, stride=block_size, padding=block_pad + ) # (seq_len, batch_size, channel) + mask = torch.ones( + kernel_size, batch_size * num_blocks, 1, 
+                device=src.device,
+            )
+            mask = fold(
+                mask, seq_len, pad_len, num_blocks,
+                kernel=kernel_size, stride=block_size, padding=block_pad
+            )  # (seq_len, batch_size, 1)
+            output = output / mask
+
         return output
 
     def streaming_forward(
@@ -1523,7 +1570,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
-        key_padding_mask: Optional[Tensor] = None,
+        attn_offsets: Optional[Tensor] = None,
+        all_pad_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""
@@ -1631,6 +1679,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
 
         if attn_mask is not None:
+            assert attn_mask is None
             assert attn_mask.dtype == torch.bool
             # use -1000 to avoid nan's where attn_mask and key_padding_mask make
             # all scores zero.  It's important that this be large enough that exp(-1000)
@@ -1638,12 +1687,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # compares the final weights with zero.
             attn_scores = attn_scores.masked_fill(attn_mask, -1000)
 
-        if key_padding_mask is not None:
-            assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
-            attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
-                -1000,
-            )
+        if attn_offsets is not None:
+            # attn_offsets: (1, batch_size, 1, seq_len)
+            # or (1, new_batch_size, 1, kernel)
+            attn_scores = attn_scores + attn_offsets
 
         # We use our own version of softmax, defined in scaling.py, which should
         # save a little of the memory used in backprop by, if we are in
@@ -1651,6 +1698,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # half-precision output for backprop purposes.
         attn_weights = softmax(attn_scores, dim=-1)
 
+        if all_pad_mask is not None:
+            # Zero out weights for blocks that are entirely padding
+            # all_pad_mask: (1, new_batch_size, 1, 1)
+            attn_weights = attn_weights.masked_fill(all_pad_mask, 0.0)
+
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif random.random() < 0.001 and not self.training:
@@ -2586,13 +2638,13 @@ def _test_zipformer_main(causal: bool = False):
         encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         downsampling_factor=(1, 2),
         max_block_size=14,
-        block_pad=1,
+        block_pad=2,
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
        left_context_frames=(64,)
     )
     batch_size = 2
-    seq_len = 29
+    seq_len = 27
     # Just make sure the forward pass runs.
     x = torch.randn(seq_len, batch_size, 64)
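
Below is a standalone sketch (not part of the patch) of the overlap-and-add step in
Zipformer2Encoder.forward above. It uses only PyTorch; `unfold` is re-derived here,
since this diff shows only part of the scaling.py version, and max_block_size=512 /
block_pad=16 are illustrative values rather than values taken from the recipe config.

import math

import torch
from torch import Tensor, nn


def unfold(x: Tensor, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int) -> Tensor:
    """(seq_len, batch, channel) -> (kernel, batch * num_blocks, channel).
    num_blocks is implied by the other arguments; kept only for signature parity."""
    seq_len, batch, channel = x.size()
    x = x.permute(1, 2, 0)                               # (batch, channel, seq_len)
    x = nn.functional.pad(x, pad=(0, x_pad), value=0.0)  # pad to a multiple of stride
    blocks = nn.functional.unfold(
        x.unsqueeze(-1), kernel_size=(kernel, 1), padding=(padding, 0), stride=(stride, 1)
    )                                                    # (batch, channel * kernel, num_blocks)
    blocks = blocks.reshape(batch, channel, kernel, num_blocks)
    return blocks.permute(2, 0, 3, 1).reshape(kernel, batch * num_blocks, channel)


def fold(blocks: Tensor, seq_len: int, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int) -> Tensor:
    """Inverse of unfold(): sums overlapping frames back to (seq_len, batch, channel)."""
    batch = blocks.size(1) // num_blocks
    channel = blocks.size(2)
    blocks = blocks.reshape(kernel, batch, num_blocks, channel)
    blocks = blocks.permute(1, 3, 0, 2).reshape(batch, channel * kernel, num_blocks)
    x = nn.functional.fold(
        blocks, output_size=(seq_len + x_pad, 1), kernel_size=(kernel, 1),
        padding=(padding, 0), stride=(stride, 1),
    )
    return x.squeeze(-1).permute(2, 0, 1)[:seq_len]      # (seq_len, batch, channel)


if __name__ == "__main__":
    seq_len, batch, channel = 600, 2, 8     # > 512, so the encoder would chunk it
    max_block_size, block_pad = 512, 16     # illustrative; the patch reads these from self
    num_blocks = math.ceil(seq_len / max_block_size)
    block_size = math.ceil(seq_len / num_blocks)
    pad_len = num_blocks * block_size - seq_len
    kernel = block_size + 2 * block_pad

    x = torch.randn(seq_len, batch, channel)
    blocks = unfold(x, pad_len, num_blocks, kernel=kernel, stride=block_size, padding=block_pad)

    # fold() sums overlapping regions, so frames covered by 2 blocks are counted twice.
    summed = fold(blocks, seq_len, pad_len, num_blocks, kernel=kernel, stride=block_size, padding=block_pad)

    # Folding an all-ones tensor gives, per frame, the number of blocks covering it ...
    ones = torch.ones(kernel, batch * num_blocks, 1)
    counts = fold(ones, seq_len, pad_len, num_blocks, kernel=kernel, stride=block_size, padding=block_pad)

    # ... so dividing restores the original sequence (the `output / mask` step above).
    assert torch.allclose(summed / counts, x, atol=1e-5)

In this sketch the blocks are untouched, so the round trip is exact; in the encoder the
blocks have been processed by the layers, so the same division instead averages the
per-block outputs over the frames where adjacent blocks overlap.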
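A second small sketch of why the additive attn_offsets plus all_pad_mask in
RelPositionMultiheadAttentionWeights reproduce the old key_padding_mask behaviour.
The shapes and the -1000 constant follow the patch; the toy scores and masks are
illustrative only.

import torch

num_heads, batch_size, seq_len = 2, 3, 5
attn_scores = torch.randn(num_heads, batch_size, seq_len, seq_len)

key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
key_padding_mask[1, 3:] = True   # last two keys of utterance 1 are padding
key_padding_mask[2, :] = True    # utterance 2 is entirely padding (an all-pad block)

# Old behaviour: boolean mask, filled with -1000 before softmax.
old = attn_scores.masked_fill(key_padding_mask.unsqueeze(1), -1000).softmax(dim=-1)

# New behaviour: additive offsets of shape (1, batch_size, 1, seq_len),
# 0 for valid keys and -1000 for padded keys.
attn_offsets = torch.zeros(batch_size, seq_len).masked_fill(key_padding_mask, -1000)
attn_offsets = attn_offsets.unsqueeze(1).unsqueeze(0)
new = (attn_scores + attn_offsets).softmax(dim=-1)

# Rows whose keys are all padding are meaningless either way; the patch zeroes
# them with all_pad_mask of shape (1, batch_size, 1, 1).
all_pad_mask = key_padding_mask.all(dim=-1)[None, :, None, None]
new = new.masked_fill(all_pad_mask, 0.0)

# Wherever at least one key is valid, the two maskings give the same weights,
# since exp(-1000) and exp(score - 1000) both underflow to zero.
assert torch.allclose(old[:, :2], new[:, :2])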