fix bug of relative position
This commit is contained in:
parent 630626a092 · commit 7b15596495
@@ -1540,8 +1540,8 @@ class EmformerEncoder(nn.Module):
         key_indexes = torch.cat(
             [memory_indexes, right_context_indexes, utterance_indexes]
         ).to(device=x.device)
-        # calculate relative position and flip sign
-        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # calculate relative position
+        rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
         # shift to start from zero
         rel_pos = rel_pos - rel_pos.min()
@@ -1697,8 +1697,8 @@ class EmformerEncoder(nn.Module):
                 utterance_indexes,
             ]
         ).to(device=x.device)
-        # calculate relative position and flip sign
-        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # calculate relative position
+        rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
         # shift to start from zero
         rel_pos = rel_pos - rel_pos.min()
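The two hunks above make the same change in two code paths: the relative position is now computed as query index minus key index, with no extra sign flip, and is then shifted so the smallest value becomes zero. A minimal, self-contained sketch of that convention (the index values below are made up, not the real Emformer memory/right-context/utterance buffers):

import torch

# toy stand-ins for the real query/key index tensors (values are made up)
query_indexes = torch.tensor([0, 1, 2, 3])        # e.g. utterance frames
key_indexes = torch.tensor([-2, -1, 0, 1, 2, 3])  # e.g. memory + utterance frames

# calculate relative position (query minus key, no sign flip)
rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
# shift to start from zero
rel_pos = rel_pos - rel_pos.min()
print(rel_pos)  # a (4, 6) matrix of nonnegative row indices

Shifting by the minimum turns the signed offsets into nonnegative indices that can address rows of a relative positional-encoding table directly.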
@@ -2071,10 +2071,7 @@ class RelPositionalEncoding(torch.nn.Module):
         )
         pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
         pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
-        # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"  # noqa
-        self.pe_positive = torch.flip(pe_positive, [0])
+        self.pe_positive = pe_positive

     def gen_pe_negative(self) -> None:
         """Generate the negative positional encodings."""
@@ -2091,7 +2088,8 @@ class RelPositionalEncoding(torch.nn.Module):
         )
         pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
         pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
-        self.pe_negative = pe_negative
+        # Reserve the order of negative indices
+        self.pe_negative = torch.flip(pe_negative, [0])

     def get_pe(
         self,
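These two hunks move the torch.flip from the positive encodings to the negative ones: the positive half now stays in natural order (positions 0, 1, 2, ...), while the negative half is reversed so it runs from the most negative position up to 0. A rough sketch of the generation after the change, assuming (for illustration only, not the exact icefall code) that position_positive and position_negative both count 0 .. len-1 and that div_term is the usual inverse-frequency vector; the sizes are made up:

import math

import torch

d_model, pos_len, neg_len = 8, 4, 4
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / d_model)
)

# positive positions 0 .. pos_len-1, kept in natural order
position_positive = torch.arange(0, pos_len, dtype=torch.float32).unsqueeze(1)
pe_positive = torch.zeros(pos_len, d_model)
pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
pe_positive[:, 1::2] = torch.cos(position_positive * div_term)

# negative positions 0 .. -(neg_len-1), flipped so -(neg_len-1) comes first
position_negative = torch.arange(0, neg_len, dtype=torch.float32).unsqueeze(1)
pe_negative = torch.zeros(neg_len, d_model)
pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
pe_negative = torch.flip(pe_negative, [0])

Before the commit, the flip (and the accompanying Transformer-XL comment) sat on pe_positive instead, paired with the sign flip on rel_pos shown in the earlier hunks.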
@@ -2111,8 +2109,8 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
         pe = torch.cat(
             [
-                self.pe_positive[self.pos_len - pos_len :],
-                self.pe_negative[1:neg_len],
+                self.pe_negative[self.neg_len - neg_len :],
+                self.pe_positive[1:pos_len],
             ],
             dim=0,
         )
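The last hunk updates get_pe to match: the flipped negative half comes first, followed by the positive half from position 1 upward, so the concatenated table covers relative positions in ascending order from the most negative to the most positive. A self-contained sketch, with hypothetical lengths and with table rows labelled by the relative position they stand for (the real table stores sin/cos vectors), of why this ordering lines up with the shifted query-minus-key positions:

import torch

num_frames = 4
# rows for relative positions -(num_frames-1) .. 0, already in the flipped order
pe_negative = torch.arange(-(num_frames - 1), 1).unsqueeze(1)
# rows for relative positions 0 .. num_frames-1, natural order
pe_positive = torch.arange(0, num_frames).unsqueeze(1)
# concatenate, skipping the duplicate row for position 0
pe = torch.cat([pe_negative, pe_positive[1:]], dim=0)
print(pe.squeeze(1))  # tensor([-3, -2, -1,  0,  1,  2,  3])

# relative position = query minus key, shifted to start from zero,
# then used as a row index into the table
query_indexes = torch.arange(num_frames)
key_indexes = torch.arange(num_frames)
rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
rows = rel_pos - rel_pos.min()
# every entry picks the row whose label equals query minus key
assert torch.equal(pe.squeeze(1)[rows], rel_pos)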