This commit is contained in:
Yuekai Zhang 2022-12-28 11:20:38 +08:00 committed by GitHub
parent 05dfd5e630
commit 3c54333b06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -966,10 +966,22 @@ class RelPositionMultiheadAttention(nn.Module):
(batch_size, num_heads, time1, n) = x.shape
time2 = time1 + left_context
if not torch.jit.is_tracing():
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
if torch.jit.is_tracing():
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time2)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, time1, time2)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)