small fixes

This commit is contained in:
danqing fu 2023-06-08 10:40:15 +08:00
parent 851bb50cb8
commit 1378f833bd
2 changed files with 3 additions and 3 deletions

View File

@ -430,7 +430,7 @@ def export_encoder_model_onnx(
encoder_model,
(x, init_state),
encoder_filename,
verbose=False,
verbose=True,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,

View File

@ -1682,12 +1682,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
cols = torch.arange(k_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.