From 72dc1780d6fabe224ef00d3a33b6e596e578580b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 2 Aug 2022 20:39:55 +0800 Subject: [PATCH] Fix as_strided for streaming conformer. --- .../ASR/pruned_transducer_stateless2/conformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 00549c086..e95360d1d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -1001,7 +1001,7 @@ class RelPositionMultiheadAttention(nn.Module): """Compute relative positional encoding. Args: - x: Input tensor (batch, head, time1, 2*time1-1). + x: Input tensor (batch, head, time1, 2*time1-1+left_context). time1 means the length of query vector. left_context (int): left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, @@ -1022,13 +1022,13 @@ class RelPositionMultiheadAttention(nn.Module): if torch.jit.is_tracing(): rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(time1) + cols = torch.arange(time2) rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) indexes = rows + cols x = x.reshape(-1, n) x = torch.gather(x, dim=1, index=indexes) - x = x.reshape(batch_size, num_heads, time1, time1) + x = x.reshape(batch_size, num_heads, time1, time2) return x else: # Note: TorchScript requires explicit arg for stride()