mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
fix bug (#796)
This commit is contained in:
parent
05dfd5e630
commit
3c54333b06
@ -966,20 +966,32 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
(batch_size, num_heads, time1, n) = x.shape
|
(batch_size, num_heads, time1, n) = x.shape
|
||||||
|
|
||||||
time2 = time1 + left_context
|
time2 = time1 + left_context
|
||||||
assert (
|
if not torch.jit.is_tracing():
|
||||||
n == left_context + 2 * time1 - 1
|
assert (
|
||||||
), f"{n} == {left_context} + 2 * {time1} - 1"
|
n == left_context + 2 * time1 - 1
|
||||||
|
), f"{n} == {left_context} + 2 * {time1} - 1"
|
||||||
|
|
||||||
# Note: TorchScript requires explicit arg for stride()
|
if torch.jit.is_tracing():
|
||||||
batch_stride = x.stride(0)
|
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||||
head_stride = x.stride(1)
|
cols = torch.arange(time2)
|
||||||
time1_stride = x.stride(2)
|
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||||
n_stride = x.stride(3)
|
indexes = rows + cols
|
||||||
return x.as_strided(
|
|
||||||
(batch_size, num_heads, time1, time2),
|
x = x.reshape(-1, n)
|
||||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
x = torch.gather(x, dim=1, index=indexes)
|
||||||
storage_offset=n_stride * (time1 - 1),
|
x = x.reshape(batch_size, num_heads, time1, time2)
|
||||||
)
|
return x
|
||||||
|
else:
|
||||||
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
|
batch_stride = x.stride(0)
|
||||||
|
head_stride = x.stride(1)
|
||||||
|
time1_stride = x.stride(2)
|
||||||
|
n_stride = x.stride(3)
|
||||||
|
return x.as_strided(
|
||||||
|
(batch_size, num_heads, time1, time2),
|
||||||
|
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||||
|
storage_offset=n_stride * (time1 - 1),
|
||||||
|
)
|
||||||
|
|
||||||
def multi_head_attention_forward(
|
def multi_head_attention_forward(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user