From 06232dce2e295fa91f227c520e99848910ef9056 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 5 Jun 2024 14:54:42 +0800
Subject: [PATCH 1/2] WIP: Begin to add Contextual positional encoding

---
 egs/librispeech/ASR/zipformer/test_cope.py | 22 ++++++++
 egs/librispeech/ASR/zipformer/zipformer.py | 58 +++++++++++++++++++++-
 2 files changed, 79 insertions(+), 1 deletion(-)
 create mode 100755 egs/librispeech/ASR/zipformer/test_cope.py

diff --git a/egs/librispeech/ASR/zipformer/test_cope.py b/egs/librispeech/ASR/zipformer/test_cope.py
new file mode 100755
index 000000000..5eb6ccfd9
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/test_cope.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+
+from zipformer import ContextualPositionalEncoding
+
+
+def test():
+    embed_dim = 5
+    npos_max = 10
+    cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
+    q = torch.rand(2, 3, 4, embed_dim)
+    qk = torch.rand(2, 3, 4, 6)
+
+    p = cope(q=q, qk=qk)
+    print(p.shape)
+
+
+def main():
+    test()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 69059287b..6a94e3ab0 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -95,6 +95,7 @@ class Zipformer2(EncoderInterface):
            context chunks for causal training; will be rounded to a number of chunks.
            Must not be less than cnn_module_kernel (after factoring in rounding and
            downsampling); an error will be thrown if this is violated.
+        use_cope (bool): If true, use contextual positional encoding
     """
 
     def __init__(
@@ -116,6 +117,7 @@ class Zipformer2(EncoderInterface):
         causal: bool = False,
         chunk_size: Tuple[int] = [-1],
         left_context_frames: Tuple[int] = [-1],
+        use_cope: bool = False,
     ) -> None:
         super(Zipformer2, self).__init__()
 
@@ -183,6 +185,7 @@ class Zipformer2(EncoderInterface):
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                 final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+                use_cope=use_cope,
             )
 
             if downsampling_factor[i] != 1:
@@ -1017,6 +1020,7 @@ class Zipformer2Encoder(nn.Module):
         warmup_end: float,
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
+        use_cope: bool = False,
     ) -> None:
         super().__init__()
         self.encoder_pos = CompactRelPositionalEncoding(
@@ -1372,6 +1376,54 @@ class SimpleUpsample(torch.nn.Module):
         return src
 
 
+class ContextualPositionalEncoding(torch.nn.Module):
+    """
+    This class implements the following paper:
+    Contextual Position Encoding: Learning to Count What's Important
+    https://arxiv.org/abs/2405.18719
+
+    Args:
+      embed_dim: Embedding dimension.
+      npos_max: The maximum context size.
+    """
+
+    def __init__(self, embed_dim: int, npos_max: int):
+        super().__init__()
+        self.npos_max = npos_max
+        self.embedding = nn.Embedding(
+            num_embeddings=npos_max,
+            embedding_dim=embed_dim,
+        )
+
+    def forward(self, q: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          q (torch.Tensor): A tensor of shape (head, batch, time1, query_head_dim)
+          qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
+        Returns:
+          Return a tensor of shape (head, batch, time1, npos_max)
+        """
+        gates = torch.sigmoid(qk)
+        pos = gates.sum(dim=-1, keepdim=True)  # (head, batch, dim1, 1)
+        # Note: We don't use cumulative sum here for non-streaming
+        # speech recognition
+
+        pos = pos.clamp(max=self.npos_max - 1)
+        pos_ceil = pos.ceil().long()
+        pos_floor = pos.floor().long()
+        logits_int = torch.matmul(
+            q, self.embedding.weight.t()
+        )  # (head, batch, time1, npos_max)
+        logits_cell = logits_int.gather(-1, pos_ceil.expand(*logits_int.shape))
+        logits_floor = logits_int.gather(-1, pos_floor.expand(*logits_int.shape))
+
+        w = pos - pos_floor
+        return logits_cell * w + logits_floor * (1 - w)
+
+    def streaming_forward(self):
+        raise RuntimeError("To be implemented")
+
+
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
     Relative positional encoding module.  This version is "compact" meaning it is able to encode
@@ -1609,7 +1661,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         k = x[..., query_dim : 2 * query_dim]
         # p is the position-encoding query
         p = x[..., 2 * query_dim :]
-        assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim)
+        assert p.shape[-1] == num_heads * pos_head_dim, (
+            p.shape[-1],
+            num_heads,
+            pos_head_dim,
+        )
 
         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.

From 36808b89406d97ee5ab68c43136a509eb0d193fc Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 3 Jul 2024 13:43:03 +0800
Subject: [PATCH 2/2] fix the implementation of CoPE

---
 egs/librispeech/ASR/zipformer/test_cope.py |  8 +++-
 egs/librispeech/ASR/zipformer/zipformer.py | 49 ++++++++++++++++++----
 2 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/test_cope.py b/egs/librispeech/ASR/zipformer/test_cope.py
index 5eb6ccfd9..00acbcb1d 100755
--- a/egs/librispeech/ASR/zipformer/test_cope.py
+++ b/egs/librispeech/ASR/zipformer/test_cope.py
@@ -1,14 +1,17 @@
 #!/usr/bin/env python3
 
+import torch
 from zipformer import ContextualPositionalEncoding
 
 
 def test():
     embed_dim = 5
     npos_max = 10
+
     cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
-    q = torch.rand(2, 3, 4, embed_dim)
-    qk = torch.rand(2, 3, 4, 6)
+    q = torch.rand(2, 3, npos_max, embed_dim)
+
+    qk = torch.rand(2, 3, npos_max, npos_max)
 
     p = cope(q=q, qk=qk)
     print(p.shape)
@@ -19,4 +22,5 @@ def main():
 
 
 if __name__ == "__main__":
+    torch.manual_seed(20240703)
     main()
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 6a94e3ab0..a62cd54f1 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -1402,26 +1402,59 @@ class ContextualPositionalEncoding(torch.nn.Module):
           qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
         Returns:
           Return a tensor of shape (head, batch, time1, npos_max)
+
+        Note the implementation assumes time1 == time2 and npos_max <= time2.
+        The implementation is reasonable for the streaming ASR encoder where
+        only self attention is used.
         """
+        # The implementation on page 13 Listing 1 from the paper does not use
+        # a mask to ensure that only gates[:, :, i, j] where j < i is computed.
+        #
+        # Here we fix that by introducing a mask
+        mask = torch.triu(
+            torch.full((qk.size(3), qk.size(3)), True, dtype=torch.bool),
+            diagonal=0,
+        )
+        #
+        # if qk.size(3) is 4, mask is
+        #
+        # tensor([[ True,  True,  True,  True],
+        #         [False,  True,  True,  True],
+        #         [False, False,  True,  True],
+        #         [False, False, False,  True]])
+        #
+        # mask[i, j] is True if j >= i
         gates = torch.sigmoid(qk)
-        pos = gates.sum(dim=-1, keepdim=True)  # (head, batch, dim1, 1)
-        # Note: We don't use cumulative sum here for non-streaming
-        # speech recognition
+
+        # We don't use an in-place operation here for the sake of autograd
+        gates = gates.masked_fill(mask, 0)
+
+        # cumsum() is an inclusive sum in PyTorch
+        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)  # (head, batch, time1, time2)
+        # pos[:, :, i, j] should be 0 for j >= i
+        # pos[:, :, i, j] contains the position between i and j. If gates
+        # is a 0-1 matrix, then pos[:, :, i, j] equals i - j (for j < i)
+        # Note: The paper says on page 4 it equals i - j + 1 instead of i - j.
 
         pos = pos.clamp(max=self.npos_max - 1)
         pos_ceil = pos.ceil().long()
         pos_floor = pos.floor().long()
+
+        # We assume query_head_dim equals embed_dim
+
         logits_int = torch.matmul(
             q, self.embedding.weight.t()
         )  # (head, batch, time1, npos_max)
-        logits_cell = logits_int.gather(-1, pos_ceil.expand(*logits_int.shape))
-        logits_floor = logits_int.gather(-1, pos_floor.expand(*logits_int.shape))
+
+        # We assume that npos_max <= time2
+        logits_cell = logits_int.gather(-1, pos_ceil)
+        logits_floor = logits_int.gather(-1, pos_floor)
 
         w = pos - pos_floor
-        return logits_cell * w + logits_floor * (1 - w)
 
-    def streaming_forward(self):
-        raise RuntimeError("To be implemented")
+        # Note: The code in the paper on page 13 is correct
+        # while the description on page 4 equation (5) is wrong
+        return logits_cell * w + logits_floor * (1 - w)
 
 
 class CompactRelPositionalEncoding(torch.nn.Module):
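
For reference, the computation implemented by PATCH 2/2 can be exercised on its own with a sketch like the one below: mask out non-causal gates, take a reversed inclusive cumulative sum to get fractional positions, then interpolate between the floor and ceil position embeddings. It assumes time1 == time2 and query_head_dim == embed_dim (both noted in the patch comments); the function name cope_bias and the free-standing embedding tensor are illustrative only and do not appear in the patches.

import torch


def cope_bias(q: torch.Tensor, qk: torch.Tensor,
              embedding: torch.Tensor, npos_max: int) -> torch.Tensor:
    # q:         (head, batch, time1, embed_dim) queries
    # qk:        (head, batch, time1, time2) raw attention scores, time1 == time2
    # embedding: (npos_max, embed_dim) learned position embeddings
    #
    # Keep only gates[..., i, j] with j < i, as in the masked_fill of the patch.
    mask = torch.triu(
        torch.full((qk.size(-1), qk.size(-1)), True, dtype=torch.bool),
        diagonal=0,
    )
    gates = torch.sigmoid(qk).masked_fill(mask, 0.0)

    # Reversed inclusive cumulative sum: pos[..., i, j] counts the gated
    # positions between j and i; it is 0 for j >= i.
    pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
    pos = pos.clamp(max=npos_max - 1)
    pos_ceil = pos.ceil().long()
    pos_floor = pos.floor().long()

    # Interpolate between the two nearest integer position embeddings.
    logits_int = torch.matmul(q, embedding.t())  # (head, batch, time1, npos_max)
    logits_ceil = logits_int.gather(-1, pos_ceil)
    logits_floor = logits_int.gather(-1, pos_floor)
    w = pos - pos_floor
    return logits_ceil * w + logits_floor * (1 - w)


if __name__ == "__main__":
    head, batch, time, embed_dim, npos_max = 2, 3, 10, 5, 10
    q = torch.rand(head, batch, time, embed_dim)
    qk = torch.rand(head, batch, time, time)
    emb = torch.rand(npos_max, embed_dim)
    print(cope_bias(q, qk, emb, npos_max).shape)  # torch.Size([2, 3, 10, 10])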