mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
small fixes
This commit is contained in:
parent
851bb50cb8
commit
1378f833bd
@ -430,7 +430,7 @@ def export_encoder_model_onnx(
|
||||
encoder_model,
|
||||
(x, init_state),
|
||||
encoder_filename,
|
||||
verbose=False,
|
||||
verbose=True,
|
||||
opset_version=opset_version,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
|
||||
@ -1682,12 +1682,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
if torch.jit.is_tracing():
|
||||
(num_heads, batch_size, time1, n) = pos_scores.shape
|
||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||
cols = torch.arange(seq_len)
|
||||
cols = torch.arange(k_len)
|
||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||
indexes = rows + cols
|
||||
pos_scores = pos_scores.reshape(-1, n)
|
||||
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
||||
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
|
||||
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
|
||||
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
||||
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||
# not, but let this code define which way round it is supposed to be.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user