mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

fix bug of relative position

commit 7b15596495
parent 630626a092
@@ -1540,8 +1540,8 @@ class EmformerEncoder(nn.Module):
         key_indexes = torch.cat(
             [memory_indexes, right_context_indexes, utterance_indexes]
         ).to(device=x.device)
-        # calculate relative position and flip sign
-        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # calculate relative position
+        rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
         # shift to start from zero
         rel_pos = rel_pos - rel_pos.min()

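This hunk and the next one make the same change: drop the sign flip when computing relative positions. As a minimal sketch (the index values below are hypothetical, not taken from the diff), entry [i, j] of the corrected rel_pos is query position i minus key position j, and subtracting the minimum shifts the matrix into a non-negative range suitable for indexing a positional-encoding table:

import torch

query_indexes = torch.tensor([3, 4, 5])          # hypothetical query frame positions
key_indexes = torch.tensor([0, 1, 2, 3, 4, 5])   # hypothetical key frame positions

# calculate relative position (new behaviour, no sign flip)
rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
# -> [[ 3,  2,  1,  0, -1, -2],
#     [ 4,  3,  2,  1,  0, -1],
#     [ 5,  4,  3,  2,  1,  0]]

# shift to start from zero
rel_pos = rel_pos - rel_pos.min()
# -> [[5, 4, 3, 2, 1, 0],
#     [6, 5, 4, 3, 2, 1],
#     [7, 6, 5, 4, 3, 2]]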
@@ -1697,8 +1697,8 @@ class EmformerEncoder(nn.Module):
                 utterance_indexes,
             ]
         ).to(device=x.device)
-        # calculate relative position and flip sign
-        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # calculate relative position
+        rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
         # shift to start from zero
         rel_pos = rel_pos - rel_pos.min()

@@ -2071,10 +2071,7 @@ class RelPositionalEncoding(torch.nn.Module):
         )
         pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
         pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
-        # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" # noqa
-        self.pe_positive = torch.flip(pe_positive, [0])
+        self.pe_positive = pe_positive

     def gen_pe_negative(self) -> None:
         """Generate the negative positional encodings."""
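For context, a minimal sketch of the positive table this hunk touches (d_model and pos_len below are made-up sizes; the variable names mirror the diff): gen_pe_positive fills row k with the sinusoidal encoding of relative position +k, and after this change the table is stored in that natural order instead of being flipped:

import math
import torch

d_model, pos_len = 8, 5                           # hypothetical sizes
pe_positive = torch.zeros(pos_len, d_model)
position_positive = torch.arange(0, pos_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / d_model)
)
pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
# row k now corresponds to relative position +k, stored without torch.flip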
@@ -2091,7 +2088,8 @@ class RelPositionalEncoding(torch.nn.Module):
         )
         pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
         pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
-        self.pe_negative = pe_negative
+        # Reserve the order of negative indices
+        self.pe_negative = torch.flip(pe_negative, [0])

     def get_pe(
         self,
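Conversely, row k of pe_negative encodes relative position -k, so the flip added here reorders its rows from the most negative position up to 0. A tiny stand-in (scalar rows instead of full sinusoidal encodings):

import torch

pe_negative = torch.tensor([[0.0], [-1.0], [-2.0]])  # stand-in rows: row k <-> position -k
pe_negative = torch.flip(pe_negative, [0])
# -> [[-2.], [-1.], [0.]]  (rows now run -2, -1, 0)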
@@ -2111,8 +2109,8 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
         pe = torch.cat(
             [
-                self.pe_positive[self.pos_len - pos_len :],
-                self.pe_negative[1:neg_len],
+                self.pe_negative[self.neg_len - neg_len :],
+                self.pe_positive[1:pos_len],
             ],
             dim=0,
         )
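Putting the two tables together, a minimal sketch of the new concatenation order in get_pe (lengths and scalar stand-in rows are hypothetical, and the stored table lengths are assumed to equal the requested ones, so the tail slice keeps every row): the flipped negative table followed by the positive table, with the positive table's duplicate position 0 skipped, yields rows ordered from the most negative to the most positive relative position, which is what the shifted rel_pos above indexes into:

import torch

neg_len, pos_len = 4, 4
pe_negative = torch.arange(0, -neg_len, -1, dtype=torch.float32).unsqueeze(1)  # row k <-> position -k
pe_positive = torch.arange(pos_len, dtype=torch.float32).unsqueeze(1)          # row k <-> position +k
pe_negative = torch.flip(pe_negative, [0])                                     # rows: -3, -2, -1, 0

pe = torch.cat(
    [
        pe_negative[neg_len - neg_len :],  # full tail here, since stored len == requested len
        pe_positive[1:pos_len],            # skip row 0, which duplicates position 0
    ],
    dim=0,
)
# pe.squeeze() -> [-3., -2., -1., 0., 1., 2., 3.]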