mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
fix the implementation of CoPE
This commit is contained in:
parent
06232dce2e
commit
36808b8940
@ -1,14 +1,17 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from zipformer import ContextualPositionalEncoding
|
||||
|
||||
|
||||
def test():
|
||||
embed_dim = 5
|
||||
npos_max = 10
|
||||
|
||||
cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
|
||||
q = torch.rand(2, 3, 4, embed_dim)
|
||||
qk = torch.rand(2, 3, 4, 6)
|
||||
q = torch.rand(2, 3, npos_max, embed_dim)
|
||||
|
||||
qk = torch.rand(2, 3, npos_max, npos_max)
|
||||
|
||||
p = cope(q=q, qk=qk)
|
||||
print(p.shape)
|
||||
@ -19,4 +22,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(20240703)
|
||||
main()
|
||||
|
@ -1402,26 +1402,59 @@ class ContextualPositionalEncoding(torch.nn.Module):
|
||||
qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
|
||||
Returns:
|
||||
Return a tensor of shape (head, batch, time1, npos_max)
|
||||
|
||||
Note the implementation assumes time1 == time2 and npos_max <= time2.
|
||||
The implementation is reasonable for the streaming ASR encoder where
|
||||
only self attention is used.
|
||||
"""
|
||||
# The implementation on page 13 Listing 1 from the paper does not use
|
||||
# a mask to ensure that only gates[:, :, i, j] where j < i is computed.
|
||||
#
|
||||
# Here we fix that by introducing a mask
|
||||
mask = torch.triu(
|
||||
torch.full((qk.size(3), qk.size(3)), True, dtype=torch.bool),
|
||||
diagonal=0,
|
||||
)
|
||||
#
|
||||
# if qk.size(3) is 4, mask is
|
||||
#
|
||||
# tensor([[ True, True, True, True],
|
||||
# [False, True, True, True],
|
||||
# [False, False, True, True],
|
||||
# [False, False, False, True]])
|
||||
#
|
||||
# mask[i, j] is True if i >= j
|
||||
gates = torch.sigmoid(qk)
|
||||
pos = gates.sum(dim=-1, keepdim=True) # (head, batch, dim1, 1)
|
||||
# Note: We don't use cumulative sum here for non-streaming
|
||||
# speech recognition
|
||||
|
||||
# We don't use an in-place operation here for the sake of autograd
|
||||
gates = gates.masked_fill(mask, 0)
|
||||
|
||||
# cumsum() is an inclusive sum in PyTorch
|
||||
pos = gates.flip(-1).cumsum(dim=-1).flip(-1) # (head, batch, time1, time2)
|
||||
# pos[:, :, i, j] should be 0 for j >= i
|
||||
# pos[:, :, i, j] contains the position between i and j. If gates
|
||||
# is a 0-1 matrix, then pos[:, :, i, j] equals to i - j (for j < i)
|
||||
# Note: The paper says on page 4 it equals to i - j + 1 instead of i - j.
|
||||
|
||||
pos = pos.clamp(max=self.npos_max - 1)
|
||||
pos_ceil = pos.ceil().long()
|
||||
pos_floor = pos.floor().long()
|
||||
|
||||
# We assume query_head_dim equals to embed_dim
|
||||
|
||||
logits_int = torch.matmul(
|
||||
q, self.embedding.weight.t()
|
||||
) # (head, batch, time1, npos_max)
|
||||
logits_cell = logits_int.gather(-1, pos_ceil.expand(*logits_int.shape))
|
||||
logits_floor = logits_int.gather(-1, pos_floor.expand(*logits_int.shape))
|
||||
|
||||
# We assume that npos_max <= time2
|
||||
logits_cell = logits_int.gather(-1, pos_ceil)
|
||||
logits_floor = logits_int.gather(-1, pos_floor)
|
||||
|
||||
w = pos - pos_floor
|
||||
return logits_cell * w + logits_floor * (1 - w)
|
||||
|
||||
def streaming_forward(self):
|
||||
raise RuntimeError("To be implemented")
|
||||
# Note: The code in the paper on page 13 is correct
|
||||
# while the description on page 4 equation (5) is wrong
|
||||
return logits_cell * w + logits_floor * (1 - w)
|
||||
|
||||
|
||||
class CompactRelPositionalEncoding(torch.nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user