diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index ebda2252f..43d775cff 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -430,7 +430,7 @@ def export_encoder_model_onnx( encoder_model, (x, init_state), encoder_filename, - verbose=False, + verbose=True, opset_version=opset_version, input_names=input_names, output_names=output_names, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index cf810a298..612356a50 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1682,12 +1682,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if torch.jit.is_tracing(): (num_heads, batch_size, time1, n) = pos_scores.shape rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) + cols = torch.arange(k_len) rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) indexes = rows + cols pos_scores = pos_scores.reshape(-1, n) pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) # the following .as_strided() expression converts the last axis of pos_scores from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be.