changes for the embedding layer in decoder

This commit is contained in:
luomingshuang 2022-03-09 16:49:57 +08:00
parent 1603744469
commit f1b7ab0226
2 changed files with 25 additions and 0 deletions

View File

@ -38,6 +38,7 @@ class Decoder(nn.Module):
embedding_dim: int, embedding_dim: int,
blank_id: int, blank_id: int,
context_size: int, context_size: int,
modify_embedding_layer: bool,
): ):
""" """
Args: Args:
@ -50,6 +51,8 @@ class Decoder(nn.Module):
context_size: context_size:
Number of previous words to use to predict the next word. Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram. 1 means bigram; 2 means trigram. n means (n+1)-gram.
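modify_embedding_layer:
Whether to modify the decoder embedding layer. If True, the embedding
weights are re-initialized with torch.randn scaled by embedding_dim ** -0.5,
and the embedding output is multiplied by embedding_dim ** 0.5.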
""" """
super().__init__() super().__init__()
self.embedding = nn.Embedding( self.embedding = nn.Embedding(
@ -57,6 +60,16 @@ class Decoder(nn.Module):
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
padding_idx=blank_id, padding_idx=blank_id,
) )
self.modify_embedding_layer = modify_embedding_layer
if self.modify_embedding_layer:
self.embedding.weight = nn.Parameter(
(
torch.randn(vocab_size, embedding_dim)
* (embedding_dim ** -0.5)
).detach()
)
self.embedding_dim = embedding_dim
self.blank_id = blank_id self.blank_id = blank_id
assert context_size >= 1, context_size assert context_size >= 1, context_size
@ -84,6 +97,9 @@ class Decoder(nn.Module):
Return a tensor of shape (N, U, embedding_dim). Return a tensor of shape (N, U, embedding_dim).
""" """
embedding_out = self.embedding(y) embedding_out = self.embedding(y)
if self.modify_embedding_layer:
embedding_out = self.embedding(y) * (self.embedding_dim ** 0.5)
if self.context_size > 1: if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1) embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True: if need_pad is True:

View File

@ -179,6 +179,14 @@ def get_parser():
"with this parameter before adding to the final loss.", "with this parameter before adding to the final loss.",
) )
def _str2bool(value):
    """Convert a command-line string such as "true" / "false" to a bool.

    `type=bool` must not be used with argparse: bool() of any non-empty
    string (including "False") is True, so the option could never be
    disabled from the command line.

    Raises:
      ValueError: if `value` is not a recognized boolean spelling. argparse
        catches ValueError and reports it as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")

parser.add_argument(
    "--modify-embedding",
    type=_str2bool,
    default=True,
    # Trailing space keeps the two concatenated sentences separated.
    help="When True, we modify the decoder embedding layer. "
    "When False, we don't modify the decoder embedding layer.",
)
parser.add_argument( parser.add_argument(
"--seed", "--seed",
type=int, type=int,
@ -284,6 +292,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
embedding_dim=params.embedding_dim, embedding_dim=params.embedding_dim,
blank_id=params.blank_id, blank_id=params.blank_id,
context_size=params.context_size, context_size=params.context_size,
modify_embedding_layer=params.modify_embedding,
) )
return decoder return decoder