mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Fix as_strided for streaming conformer.
This commit is contained in:
parent
6e43a2b69d
commit
72dc1780d6
@ -1001,7 +1001,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
|
||||
time1 means the length of query vector.
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
@ -1022,13 +1022,13 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
if torch.jit.is_tracing():
|
||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||
cols = torch.arange(time1)
|
||||
cols = torch.arange(time2)
|
||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
|
||||
x = x.reshape(-1, n)
|
||||
x = torch.gather(x, dim=1, index=indexes)
|
||||
x = x.reshape(batch_size, num_heads, time1, time1)
|
||||
x = x.reshape(batch_size, num_heads, time1, time2)
|
||||
return x
|
||||
else:
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
|
Loading…
x
Reference in New Issue
Block a user