mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 02:22:17 +00:00
fix the implementation of CoPE
This commit is contained in:
parent
06232dce2e
commit
36808b8940
@ -1,14 +1,17 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import torch
|
||||||
from zipformer import ContextualPositionalEncoding
|
from zipformer import ContextualPositionalEncoding
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
embed_dim = 5
|
embed_dim = 5
|
||||||
npos_max = 10
|
npos_max = 10
|
||||||
|
|
||||||
cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
|
cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
|
||||||
q = torch.rand(2, 3, 4, embed_dim)
|
q = torch.rand(2, 3, npos_max, embed_dim)
|
||||||
qk = torch.rand(2, 3, 4, 6)
|
|
||||||
|
qk = torch.rand(2, 3, npos_max, npos_max)
|
||||||
|
|
||||||
p = cope(q=q, qk=qk)
|
p = cope(q=q, qk=qk)
|
||||||
print(p.shape)
|
print(p.shape)
|
||||||
@ -19,4 +22,5 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
torch.manual_seed(20240703)
|
||||||
main()
|
main()
|
||||||
|
@ -1402,26 +1402,59 @@ class ContextualPositionalEncoding(torch.nn.Module):
|
|||||||
qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
|
qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (head, batch, time1, npos_max)
|
Return a tensor of shape (head, batch, time1, npos_max)
|
||||||
|
|
||||||
|
Note the implementation assumes time1 == time2 and npos_max <= time2.
|
||||||
|
The implementation is reasonable for the streaming ASR encoder where
|
||||||
|
only self attention is used.
|
||||||
"""
|
"""
|
||||||
|
# The implementation on page 13 Listing 1 from the paper does not use
|
||||||
|
# a mask to ensure that only gates[:, :, i, j] where j < i is computed.
|
||||||
|
#
|
||||||
|
# Here we fix that by introducing a mask
|
||||||
|
mask = torch.triu(
|
||||||
|
torch.full((qk.size(3), qk.size(3)), True, dtype=torch.bool),
|
||||||
|
diagonal=0,
|
||||||
|
)
|
||||||
|
#
|
||||||
|
# if qk.size(3) is 4, mask is
|
||||||
|
#
|
||||||
|
# tensor([[ True, True, True, True],
|
||||||
|
# [False, True, True, True],
|
||||||
|
# [False, False, True, True],
|
||||||
|
# [False, False, False, True]])
|
||||||
|
#
|
||||||
|
# mask[i, j] is True if i >= j
|
||||||
gates = torch.sigmoid(qk)
|
gates = torch.sigmoid(qk)
|
||||||
pos = gates.sum(dim=-1, keepdim=True) # (head, batch, dim1, 1)
|
|
||||||
# Note: We don't use cumulative sum here for non-streaming
|
# We don't use an in-place operation here for the sake of autograd
|
||||||
# speech recognition
|
gates = gates.masked_fill(mask, 0)
|
||||||
|
|
||||||
|
# cumsum() is an inclusive sum in PyTorch
|
||||||
|
pos = gates.flip(-1).cumsum(dim=-1).flip(-1) # (head, batch, time1, time2)
|
||||||
|
# pos[:, :, i, j] should be 0 for j >= i
|
||||||
|
# pos[:, :, i, j] contains the position between i and j. If gates
|
||||||
|
# is a 0-1 matrix, then pos[:, :, i, j] equals to i - j (for j < i)
|
||||||
|
# Note: The paper says on page 4 it equals to i - j + 1 instead of i - j.
|
||||||
|
|
||||||
pos = pos.clamp(max=self.npos_max - 1)
|
pos = pos.clamp(max=self.npos_max - 1)
|
||||||
pos_ceil = pos.ceil().long()
|
pos_ceil = pos.ceil().long()
|
||||||
pos_floor = pos.floor().long()
|
pos_floor = pos.floor().long()
|
||||||
|
|
||||||
|
# We assume query_head_dim equals to embed_dim
|
||||||
|
|
||||||
logits_int = torch.matmul(
|
logits_int = torch.matmul(
|
||||||
q, self.embedding.weight.t()
|
q, self.embedding.weight.t()
|
||||||
) # (head, batch, time1, npos_max)
|
) # (head, batch, time1, npos_max)
|
||||||
logits_cell = logits_int.gather(-1, pos_ceil.expand(*logits_int.shape))
|
|
||||||
logits_floor = logits_int.gather(-1, pos_floor.expand(*logits_int.shape))
|
# We assume that npos_max <= time2
|
||||||
|
logits_cell = logits_int.gather(-1, pos_ceil)
|
||||||
|
logits_floor = logits_int.gather(-1, pos_floor)
|
||||||
|
|
||||||
w = pos - pos_floor
|
w = pos - pos_floor
|
||||||
return logits_cell * w + logits_floor * (1 - w)
|
|
||||||
|
|
||||||
def streaming_forward(self):
|
# Note: The code in the paper on page 13 is correct
|
||||||
raise RuntimeError("To be implemented")
|
# while the description on page 4 equation (5) is wrong
|
||||||
|
return logits_cell * w + logits_floor * (1 - w)
|
||||||
|
|
||||||
|
|
||||||
class CompactRelPositionalEncoding(torch.nn.Module):
|
class CompactRelPositionalEncoding(torch.nn.Module):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user