diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless3/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless3/emformer.py
index 06fc880df..c8e202bac 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless3/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless3/emformer.py
@@ -1540,8 +1540,8 @@ class EmformerEncoder(nn.Module):
         key_indexes = torch.cat(
             [memory_indexes, right_context_indexes, utterance_indexes]
         ).to(device=x.device)
-        # calculate relative position and flip sign
-        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # calculate relative position
+        rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
         # shift to start from zero
         rel_pos = rel_pos - rel_pos.min()
 
@@ -1697,8 +1697,8 @@ class EmformerEncoder(nn.Module):
                 utterance_indexes,
             ]
         ).to(device=x.device)
-        # calculate relative position and flip sign
-        rel_pos = -(query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0))
+        # calculate relative position
+        rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)
         # shift to start from zero
         rel_pos = rel_pos - rel_pos.min()
 
@@ -2071,10 +2071,7 @@ class RelPositionalEncoding(torch.nn.Module):
         )
         pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
         pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
-        # Reserve the order of positive indices and concat both positive and
-        # negative indices. This is used to support the shifting trick
-        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"  # noqa
-        self.pe_positive = torch.flip(pe_positive, [0])
+        self.pe_positive = pe_positive
 
     def gen_pe_negative(self) -> None:
         """Generate the negative positional encodings."""
@@ -2091,7 +2088,8 @@ class RelPositionalEncoding(torch.nn.Module):
         )
         pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
         pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
-        self.pe_negative = pe_negative
+        # Reverse the order of negative indices
+        self.pe_negative = torch.flip(pe_negative, [0])
 
     def get_pe(
         self,
@@ -2111,8 +2109,8 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
         pe = torch.cat(
             [
-                self.pe_positive[self.pos_len - pos_len :],
-                self.pe_negative[1:neg_len],
+                self.pe_negative[self.neg_len - neg_len :],
+                self.pe_positive[1:pos_len],
             ],
             dim=0,
         )
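
Note on the patch (illustration, not part of the diff): under both conventions the row ultimately gathered encodes the relative position query_i - key_j. The old code stored the table in descending order (flipped pe_positive first) and compensated with a sign flip on rel_pos; the new code stores it in ascending order (flipped pe_negative first, then pe_positive) and uses the plain difference. Below is a minimal toy sketch of the new convention; the toy sizes and the standalone table construction are assumptions for illustration and do not reproduce the module's actual pos_len/neg_len bookkeeping.

import math
import torch

d_model = 4
query_indexes = torch.arange(3)  # toy query frame indexes: 0, 1, 2
key_indexes = torch.arange(5)    # toy key frame indexes: 0 .. 4

# New convention: rel_pos[i, j] = query_i - key_j (no sign flip).
rel_pos = query_indexes.unsqueeze(1) - key_indexes.unsqueeze(0)

# Sinusoidal rows for every occurring relative position, most negative
# first -- the order produced by cat([flipped pe_negative, pe_positive]).
lo, hi = int(rel_pos.min()), int(rel_pos.max())
positions = torch.arange(lo, hi + 1).unsqueeze(1).float()
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
)
pe = torch.zeros(positions.size(0), d_model)
pe[:, 0::2] = torch.sin(positions * div_term)
pe[:, 1::2] = torch.cos(positions * div_term)

# Shift to start from zero, then gather: row k encodes relative position
# lo + k, so this selects the encoding of query_i - key_j for every (i, j).
selected = pe[rel_pos - rel_pos.min()]
assert selected.shape == (3, 5, d_model)

Since row k of the ascending table encodes relative position lo + k, indexing with rel_pos - rel_pos.min() lands on query_i - key_j directly; the double negation removed by this patch was only compensating for the previously flipped table order.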