mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
small fixes
This commit is contained in:
parent
851bb50cb8
commit
1378f833bd
@ -430,7 +430,7 @@ def export_encoder_model_onnx(
|
|||||||
encoder_model,
|
encoder_model,
|
||||||
(x, init_state),
|
(x, init_state),
|
||||||
encoder_filename,
|
encoder_filename,
|
||||||
verbose=False,
|
verbose=True,
|
||||||
opset_version=opset_version,
|
opset_version=opset_version,
|
||||||
input_names=input_names,
|
input_names=input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
|
|||||||
@ -1682,12 +1682,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
if torch.jit.is_tracing():
|
if torch.jit.is_tracing():
|
||||||
(num_heads, batch_size, time1, n) = pos_scores.shape
|
(num_heads, batch_size, time1, n) = pos_scores.shape
|
||||||
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||||
cols = torch.arange(seq_len)
|
cols = torch.arange(k_len)
|
||||||
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||||
indexes = rows + cols
|
indexes = rows + cols
|
||||||
pos_scores = pos_scores.reshape(-1, n)
|
pos_scores = pos_scores.reshape(-1, n)
|
||||||
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
||||||
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
|
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
|
||||||
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
||||||
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||||
# not, but let this code define which way round it is supposed to be.
|
# not, but let this code define which way round it is supposed to be.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user