Fix as_strided for streaming conformer.
This commit is contained in:
parent
6e43a2b69d
commit
72dc1780d6
@ -1001,7 +1001,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
"""Compute relative positional encoding.
|
"""Compute relative positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
|
||||||
time1 means the length of query vector.
|
time1 means the length of query vector.
|
||||||
left_context (int): left context (in frames) used during streaming decoding.
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
@ -1022,13 +1022,13 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
if torch.jit.is_tracing():
|
if torch.jit.is_tracing():
|
||||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||||
cols = torch.arange(time1)
|
cols = torch.arange(time2)
|
||||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||||
indexes = rows + cols
|
indexes = rows + cols
|
||||||
|
|
||||||
x = x.reshape(-1, n)
|
x = x.reshape(-1, n)
|
||||||
x = torch.gather(x, dim=1, index=indexes)
|
x = torch.gather(x, dim=1, index=indexes)
|
||||||
x = x.reshape(batch_size, num_heads, time1, time1)
|
x = x.reshape(batch_size, num_heads, time1, time2)
|
||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
# Note: TorchScript requires explicit arg for stride()
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
|
|||||||
Reference in New Issue
Block a user