Fix bug in relative position computation

yaozengwei 2022-06-26 22:08:40 +08:00
parent 630626a092
commit 7b15596495


@@ -1540,8 +1540,8 @@ class EmformerEncoder(nn.Module):
         key_indexes = torch.cat(
             [memory_indexes, right_context_indexes, utterance_indexes]
         ).to(device=x.device)
-        # calculate relative position and flip sign
-        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # calculate relative position
+        rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
         # shift to start from zero
         rel_pos = rel_pos - rel_pos.min()
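
With the sign flip removed, the relative position between a query frame and a key frame is simply query index minus key index, shifted so the smallest value becomes 0 and the result can be used to look up rows of a positional-encoding table. A minimal sketch of that computation; the index tensors below are made up for illustration (in EmformerEncoder they come from the memory, right-context, and utterance layouts):

import torch

# hypothetical frame indexes for queries and keys
query_indexes = torch.arange(4)                      # e.g. utterance frames 0..3
key_indexes = torch.tensor([-2, -1, 4, 0, 1, 2, 3])  # made-up key positions

# relative position of each (query, key) pair: query index minus key index,
# without the former sign flip
rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)

# shift so the smallest relative position maps to 0
rel_pos = rel_pos - rel_pos.min()
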
@@ -1697,8 +1697,8 @@ class EmformerEncoder(nn.Module):
                 utterance_indexes,
             ]
         ).to(device=x.device)
-        # calculate relative position and flip sign
-        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # calculate relative position
+        rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
         # shift to start from zero
         rel_pos = rel_pos - rel_pos.min()
@@ -2071,10 +2071,7 @@ class RelPositionalEncoding(torch.nn.Module):
         )
         pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
         pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
-        # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"  # noqa
-        self.pe_positive = torch.flip(pe_positive, [0])
+        self.pe_positive = pe_positive
 
     def gen_pe_negative(self) -> None:
         """Generate the negative positional encodings."""
@@ -2091,7 +2088,8 @@ class RelPositionalEncoding(torch.nn.Module):
         )
         pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
         pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
-        self.pe_negative = pe_negative
+        # Reserve the order of negative indices
+        self.pe_negative = torch.flip(pe_negative, [0])
 
     def get_pe(
         self,
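
After this change, pe_positive stays in its natural order (row k encodes distance +k) and pe_negative is the table that gets flipped, so its rows run from the most negative distance up to 0. A rough standalone sketch of the two generators under that convention; the function name gen_pe, the lengths, and d_model below are illustrative, only the sin/cos layout and the flip mirror the diff:

import math
import torch

def gen_pe(length: int, d_model: int, negative: bool) -> torch.Tensor:
    """Sinusoidal encodings for distances 0, 1, ..., length-1 (negated if negative)."""
    position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / d_model)
    )
    sign = -1.0 if negative else 1.0
    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(sign * position * div_term)
    pe[:, 1::2] = torch.cos(sign * position * div_term)
    return pe

pos_len, neg_len, d_model = 5, 5, 8
pe_positive = gen_pe(pos_len, d_model, negative=False)                  # rows: 0, +1, ..., +4
pe_negative = torch.flip(gen_pe(neg_len, d_model, negative=True), [0])  # rows: -4, ..., -1, 0
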
@@ -2111,8 +2109,8 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
         pe = torch.cat(
             [
-                self.pe_positive[self.pos_len - pos_len :],
-                self.pe_negative[1:neg_len],
+                self.pe_negative[self.neg_len - neg_len :],
+                self.pe_positive[1:pos_len],
             ],
             dim=0,
         )
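
With both tables ordered by increasing relative position, get_pe now concatenates the negative part first and skips the duplicate zero row of pe_positive, so row r of the result encodes the r-th smallest relative position and the shifted rel_pos from the encoder can index it directly; that consistent ordering appears to be the point of this fix. A small continuation of the sketch above, reusing pe_positive, pe_negative, pos_len, neg_len, and d_model from it; the slicing mirrors the diff, the rel_pos lookup at the end is only an assumption about how the table is consumed:

# in the real module the cached tables can be longer than requested; here they
# are exactly the requested lengths, so the first slice keeps every row
pe = torch.cat(
    [
        pe_negative[pe_negative.size(0) - neg_len :],  # distances -(neg_len-1) .. 0
        pe_positive[1:pos_len],                        # distances +1 .. +(pos_len-1)
    ],
    dim=0,
)  # shape: (neg_len + pos_len - 1, d_model)

# a shifted rel_pos (starting from zero) can index rows of pe directly
rel_pos = torch.tensor([[4, 5, 6], [3, 4, 5]])  # hypothetical, already shifted
pos_emb = pe[rel_pos]                           # (num_queries, num_keys, d_model)
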