This commit is contained in:
yaozengwei 2022-06-26 22:34:40 +08:00
parent 7b15596495
commit 5ea58a4465

View File

@ -2109,6 +2109,8 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
pe = torch.cat(
[
# it starts from the min negative value of relative position
# and it is bound to be gathered
self.pe_negative[self.neg_len - neg_len :],
self.pe_positive[1:pos_len],
],