mirror of https://github.com/k2-fsa/icefall.git

commit f8354ee64d: Merge 36808b89406d97ee5ab68c43136a509eb0d193fc into abd9437e6d5419a497707748eb935e50976c3b7b

egs/librispeech/ASR/zipformer/test_cope.py (new executable file, 26 lines)

@@ -0,0 +1,26 @@
#!/usr/bin/env python3

import torch
from zipformer import ContextualPositionalEncoding


def test():
    embed_dim = 5
    npos_max = 10

    cope = ContextualPositionalEncoding(embed_dim=embed_dim, npos_max=npos_max)
    q = torch.rand(2, 3, npos_max, embed_dim)  # (head, batch, time1, query_head_dim)

    qk = torch.rand(2, 3, npos_max, npos_max)  # (head, batch, time1, time2)

    p = cope(q=q, qk=qk)
    print(p.shape)


def main():
    test()


if __name__ == "__main__":
    torch.manual_seed(20240703)
    main()
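As a quick sanity check (not part of the committed file): with time1 == time2 == npos_max == 10 as above, the test is expected to print torch.Size([2, 3, 10, 10]).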
@@ -95,6 +95,7 @@ class Zipformer2(EncoderInterface):
           context chunks for causal training; will be rounded to a number of
           chunks. Must not be less than cnn_module_kernel (after factoring in
           rounding and downsampling); an error will be thrown if this is violated.
        use_cope (bool): If true, use contextual positional encoding
    """

    def __init__(
@@ -116,6 +117,7 @@ class Zipformer2(EncoderInterface):
        causal: bool = False,
        chunk_size: Tuple[int] = [-1],
        left_context_frames: Tuple[int] = [-1],
        use_cope: bool = False,
    ) -> None:
        super(Zipformer2, self).__init__()
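A minimal usage sketch (not part of this commit) for the new flag; it assumes the remaining Zipformer2 constructor arguments keep their defaults:

from zipformer import Zipformer2

# Hypothetical example: build an encoder with contextual positional
# encoding enabled; every other constructor argument is left at its default.
encoder = Zipformer2(use_cope=True)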
@@ -183,6 +185,7 @@ class Zipformer2(EncoderInterface):
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
                use_cope=use_cope,
            )

            if downsampling_factor[i] != 1:
@@ -1021,6 +1024,7 @@ class Zipformer2Encoder(nn.Module):
        warmup_end: float,
        initial_layerdrop_rate: float = 0.5,
        final_layerdrop_rate: float = 0.05,
        use_cope: bool = False,
    ) -> None:
        super().__init__()
        self.encoder_pos = CompactRelPositionalEncoding(
@@ -1393,6 +1397,87 @@ class SimpleUpsample(torch.nn.Module):
        return src


class ContextualPositionalEncoding(torch.nn.Module):
    """
    This class implements the following paper:

        Contextual Position Encoding: Learning to Count What's Important
        https://arxiv.org/abs/2405.18719

    Args:
      embed_dim: Embedding dimension.
      npos_max: The maximum context size.
    """

    def __init__(self, embed_dim: int, npos_max: int):
        super().__init__()
        self.npos_max = npos_max
        self.embedding = nn.Embedding(
            num_embeddings=npos_max,
            embedding_dim=embed_dim,
        )
    def forward(self, q: torch.Tensor, qk: torch.Tensor) -> torch.Tensor:
        """
        Args:
          q (torch.Tensor): A tensor of shape (head, batch, time1, query_head_dim)
          qk (torch.Tensor): A tensor of shape (head, batch, time1, time2)
        Returns:
          Return a tensor of shape (head, batch, time1, time2)

        Note: the implementation assumes time1 == time2 and npos_max <= time2.
        It is reasonable for the streaming ASR encoder, where only
        self-attention is used.
        """
        # The implementation on page 13, Listing 1, of the paper does not use
        # a mask to ensure that only gates[:, :, i, j] with j < i is computed.
        #
        # Here we fix that by introducing a mask (kept on the same device as qk).
        mask = torch.triu(
            torch.full(
                (qk.size(3), qk.size(3)), True, dtype=torch.bool, device=qk.device
            ),
            diagonal=0,
        )
        #
        # If qk.size(3) is 4, mask is
        #
        # tensor([[ True,  True,  True,  True],
        #         [False,  True,  True,  True],
        #         [False, False,  True,  True],
        #         [False, False, False,  True]])
        #
        # i.e., mask[i, j] is True if j >= i
        gates = torch.sigmoid(qk)

        # We don't use an in-place operation here for the sake of autograd
        gates = gates.masked_fill(mask, 0)
        # cumsum() is an inclusive sum in PyTorch
        pos = gates.flip(-1).cumsum(dim=-1).flip(-1)  # (head, batch, time1, time2)
        # pos[:, :, i, j] is 0 for j >= i.
        # pos[:, :, i, j] contains the position between i and j. If gates
        # were a 0-1 matrix, pos[:, :, i, j] would equal i - j (for j < i).
        # Note: the paper says on page 4 that it equals i - j + 1 instead of i - j.

        pos = pos.clamp(max=self.npos_max - 1)
        pos_ceil = pos.ceil().long()
        pos_floor = pos.floor().long()

        # We assume query_head_dim equals embed_dim
        logits_int = torch.matmul(
            q, self.embedding.weight.t()
        )  # (head, batch, time1, npos_max)

        # We assume that npos_max <= time2
        logits_ceil = logits_int.gather(-1, pos_ceil)
        logits_floor = logits_int.gather(-1, pos_floor)

        w = pos - pos_floor

        # Note: the code in the paper on page 13 is correct,
        # while the description in equation (5) on page 4 is wrong.
        return logits_ceil * w + logits_floor * (1 - w)
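To make the masking and the flipped cumulative sum concrete, here is a small standalone sketch (not part of the commit) that drives the gates to hard 0/1 values, so that pos[:, :, i, j] becomes exactly i - j for j < i:

import torch

time2 = 4
qk = torch.full((1, 1, time2, time2), 100.0)  # sigmoid(100.0) rounds to 1.0 in float32
mask = torch.triu(
    torch.full((time2, time2), True, dtype=torch.bool), diagonal=0
)
gates = torch.sigmoid(qk).masked_fill(mask, 0)
pos = gates.flip(-1).cumsum(dim=-1).flip(-1)
print(pos[0, 0])
# tensor([[0., 0., 0., 0.],
#         [1., 0., 0., 0.],
#         [2., 1., 0., 0.],
#         [3., 2., 1., 0.]])
# With such integer positions, pos_ceil == pos_floor and w == 0, so the
# interpolation at the end of forward() reduces to a plain lookup of the
# embedding logits at position i - j.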


class CompactRelPositionalEncoding(torch.nn.Module):
    """
    Relative positional encoding module. This version is "compact" meaning it is able to encode