do some changes

This commit is contained in:
luomingshuang 2022-03-10 10:24:23 +08:00
parent f1b7ab0226
commit 18b8c670dd
2 changed files with 8 additions and 9 deletions

View File

@ -38,7 +38,7 @@ class Decoder(nn.Module):
embedding_dim: int, embedding_dim: int,
blank_id: int, blank_id: int,
context_size: int, context_size: int,
modify_embedding_layer: bool, modify_embedding_layer: bool = False,
): ):
""" """
Args: Args:
@ -52,7 +52,7 @@ class Decoder(nn.Module):
Number of previous words to use to predict the next word. Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram. 1 means bigram; 2 means trigram. n means (n+1)-gram.
modify_embedding_layer: modify_embedding_layer:
Do or not do modify the decoder embedding layer. Decide to modify the decoder embedding layer or not.
""" """
super().__init__() super().__init__()
self.embedding = nn.Embedding( self.embedding = nn.Embedding(
@ -61,13 +61,12 @@ class Decoder(nn.Module):
padding_idx=blank_id, padding_idx=blank_id,
) )
self.modify_embedding_layer = modify_embedding_layer self.modify_embedding_layer = modify_embedding_layer
self.embedding_dim_inv_sqrt = embedding_dim ** 0.5
if self.modify_embedding_layer: if self.modify_embedding_layer:
self.embedding.weight = nn.Parameter( embedding_weight = torch.randn(vocab_size, embedding_dim) * (
( 1 / self.embedding_dim_inv_sqrt
torch.randn(vocab_size, embedding_dim)
* (embedding_dim ** -0.5)
).detach()
) )
self.embedding.weight = nn.Parameter(embedding_weight)
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.blank_id = blank_id self.blank_id = blank_id
@ -98,7 +97,7 @@ class Decoder(nn.Module):
""" """
embedding_out = self.embedding(y) embedding_out = self.embedding(y)
if self.modify_embedding_layer: if self.modify_embedding_layer:
embedding_out = self.embedding(y) * (self.embedding_dim ** 0.5) embedding_out = self.embedding(y) * self.embedding_dim_inv_sqrt
if self.context_size > 1: if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1) embedding_out = embedding_out.permute(0, 2, 1)

View File

@ -182,7 +182,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--modify-embedding", "--modify-embedding",
type=bool, type=bool,
default=True, default=False,
help="When True, we modify the decoder embedding layer." help="When True, we modify the decoder embedding layer."
"When False, we don't modify the decoder embedding layer.", "When False, we don't modify the decoder embedding layer.",
) )