diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 00549c086..e95360d1d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -1001,7 +1001,7 @@ class RelPositionMultiheadAttention(nn.Module): """Compute relative positional encoding. Args: - x: Input tensor (batch, head, time1, 2*time1-1). + x: Input tensor (batch, head, time1, 2*time1-1+left_context). time1 means the length of query vector. left_context (int): left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, @@ -1022,13 +1022,13 @@ class RelPositionMultiheadAttention(nn.Module): if torch.jit.is_tracing(): rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(time1) + cols = torch.arange(time2) rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) indexes = rows + cols x = x.reshape(-1, n) x = torch.gather(x, dim=1, index=indexes) - x = x.reshape(batch_size, num_heads, time1, time1) + x = x.reshape(batch_size, num_heads, time1, time2) return x else: # Note: TorchScript requires explicit arg for stride()