Fix as_strided for streaming conformer.

This commit is contained in:
Fangjun Kuang 2022-08-02 20:39:55 +08:00
parent 6e43a2b69d
commit 72dc1780d6

View File

@ -1001,7 +1001,7 @@ class RelPositionMultiheadAttention(nn.Module):
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
time1 means the length of query vector.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
@ -1022,13 +1022,13 @@ class RelPositionMultiheadAttention(nn.Module):
if torch.jit.is_tracing():
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time1)
cols = torch.arange(time2)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, time1, time1)
x = x.reshape(batch_size, num_heads, time1, time2)
return x
else:
# Note: TorchScript requires explicit arg for stride()