Fix a bug in the relative position computation

This commit is contained in:
yaozengwei 2022-06-26 22:08:40 +08:00
parent 630626a092
commit 7b15596495


@@ -1540,8 +1540,8 @@ class EmformerEncoder(nn.Module):
key_indexes = torch.cat(
[memory_indexes, right_context_indexes, utterance_indexes]
).to(device=x.device)
- # calculate relative position and flip sign
- rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+ # calculate relative position
+ rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
# shift to start from zero
rel_pos = rel_pos - rel_pos.min()
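
For illustration only (not part of the commit): a minimal sketch of the corrected relative-position computation, using small made-up index tensors for queries and keys.

import torch

query_indexes = torch.tensor([3, 4, 5])      # hypothetical query frame indexes
key_indexes = torch.tensor([0, 1, 3, 4, 5])  # hypothetical key frame indexes

# rel_pos[i, j] = query_indexes[i] - key_indexes[j]; a positive entry means
# key j lies in the past relative to query i (the usual Transformer-XL sign)
rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)

# shift so the most negative relative position maps to row 0 of the table
# returned by get_pe(), which after this change is ordered from the most
# negative relative position up to the most positive one
rel_pos = rel_pos - rel_pos.min()
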
@@ -1697,8 +1697,8 @@ class EmformerEncoder(nn.Module):
utterance_indexes,
]
).to(device=x.device)
- # calculate relative position and flip sign
- rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+ # calculate relative position
+ rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
# shift to start from zero
rel_pos = rel_pos - rel_pos.min()
@@ -2071,10 +2071,7 @@ class RelPositionalEncoding(torch.nn.Module):
)
pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
- # Reverse the order of positive indices and concat both positive and
- # negative indices. This is used to support the shifting trick
- # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" # noqa
- self.pe_positive = torch.flip(pe_positive, [0])
+ self.pe_positive = pe_positive
def gen_pe_negative(self) -> None:
"""Generate the negative positional encodings."""
@@ -2091,7 +2088,8 @@ class RelPositionalEncoding(torch.nn.Module):
)
pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
- self.pe_negative = pe_negative
+ # Reverse the order of negative indices
+ self.pe_negative = torch.flip(pe_negative, [0])
def get_pe(
self,
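
For context, a condensed sketch of how the two halves of the table are built under this change. The sizes (d_model, pos_len, neg_len) are made up, and the code approximates gen_pe_positive / gen_pe_negative rather than reproducing the class's exact methods; the flip now lands on the negative half, so its rows run from the most negative relative position up to zero.

import math
import torch

d_model, pos_len, neg_len = 8, 5, 5

div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / d_model)
)

# positive relative positions 0, 1, ..., pos_len - 1, kept in natural order
position_positive = torch.arange(0, pos_len, dtype=torch.float32).unsqueeze(1)
pe_positive = torch.zeros(pos_len, d_model)
pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
pe_positive[:, 1::2] = torch.cos(position_positive * div_term)

# negative relative positions 0, -1, ..., -(neg_len - 1), then flipped so
# row 0 encodes the most negative position and the last row encodes 0
position_negative = torch.arange(0, neg_len, dtype=torch.float32).unsqueeze(1)
pe_negative = torch.zeros(neg_len, d_model)
pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
pe_negative = torch.flip(pe_negative, [0])
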
@@ -2111,8 +2109,8 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
pe = torch.cat(
[
- self.pe_positive[self.pos_len - pos_len :],
- self.pe_negative[1:neg_len],
+ self.pe_negative[self.neg_len - neg_len :],
+ self.pe_positive[1:pos_len],
],
dim=0,
)
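
A self-contained sketch, using numeric stand-ins instead of sinusoidal rows, of why the new concatenation order in get_pe() lines up with rel_pos - rel_pos.min(): the table rows now ascend from the most negative relative position to the most positive one, and pe_positive[1:pos_len] skips the duplicate row for relative position 0 that the flipped pe_negative already ends with. Sizes here are invented for illustration.

import torch

pos_len, neg_len, d_model = 4, 4, 6
# stand-ins for the stored tables; row k of pe_positive encodes +k, and
# row k of the (already flipped) pe_negative encodes -(neg_len - 1 - k)
pe_positive = torch.arange(pos_len, dtype=torch.float32).unsqueeze(1).expand(pos_len, d_model)
pe_negative = -torch.arange(neg_len - 1, -1, -1, dtype=torch.float32).unsqueeze(1).expand(neg_len, d_model)

pe = torch.cat([pe_negative, pe_positive[1:]], dim=0)
# pe[i] now encodes relative position i - (neg_len - 1), so an index matrix
# built as rel_pos - rel_pos.min() selects the matching rows directly
print(pe[:, 0])  # tensor([-3., -2., -1., 0., 1., 2., 3.])
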