change for decoder.py

This commit is contained in:
luomingshuang 2022-03-10 11:29:12 +08:00
parent 18b8c670dd
commit eb48ade752

View File

@ -52,7 +52,7 @@ class Decoder(nn.Module):
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
modify_embedding_layer:
Decide to modify the decoder embedding layer or not.
Whether to modify the decoder embedding layer or not.
"""
super().__init__()
self.embedding = nn.Embedding(
@ -61,12 +61,18 @@ class Decoder(nn.Module):
padding_idx=blank_id,
)
self.modify_embedding_layer = modify_embedding_layer
self.embedding_dim_inv_sqrt = embedding_dim ** 0.5
self.embedding_dim_sqrt = embedding_dim ** 0.5
embedding_weight = torch.randn(vocab_size, embedding_dim) * (
1 / self.embedding_dim_sqrt
)
if self.modify_embedding_layer:
embedding_weight = torch.randn(vocab_size, embedding_dim) * (
1 / self.embedding_dim_inv_sqrt
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
_weight=embedding_weight,
)
self.embedding.weight = nn.Parameter(embedding_weight)
self.embedding._fill_padding_idx_with_zero()
self.embedding_dim = embedding_dim
self.blank_id = blank_id
@ -97,7 +103,7 @@ class Decoder(nn.Module):
"""
embedding_out = self.embedding(y)
if self.modify_embedding_layer:
embedding_out = self.embedding(y) * self.embedding_dim_inv_sqrt
embedding_out = self.embedding(y) * self.embedding_dim_sqrt
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)