mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
changes for the embedding layer in decoder
This commit is contained in:
parent
1603744469
commit
f1b7ab0226
@ -38,6 +38,7 @@ class Decoder(nn.Module):
|
|||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
blank_id: int,
|
blank_id: int,
|
||||||
context_size: int,
|
context_size: int,
|
||||||
|
modify_embedding_layer: bool,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -50,6 +51,8 @@ class Decoder(nn.Module):
|
|||||||
context_size:
|
context_size:
|
||||||
Number of previous words to use to predict the next word.
|
Number of previous words to use to predict the next word.
|
||||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||||
|
modify_embedding_layer:
|
||||||
|
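Whether to modify the decoder embedding layer. If True, the embedding weight is re-initialized from a normal distribution scaled by embedding_dim ** -0.5, and the embedding output is multiplied by embedding_dim ** 0.5.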
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding = nn.Embedding(
|
self.embedding = nn.Embedding(
|
||||||
@ -57,6 +60,16 @@ class Decoder(nn.Module):
|
|||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
padding_idx=blank_id,
|
padding_idx=blank_id,
|
||||||
)
|
)
|
||||||
|
self.modify_embedding_layer = modify_embedding_layer
|
||||||
|
if self.modify_embedding_layer:
|
||||||
|
self.embedding.weight = nn.Parameter(
|
||||||
|
(
|
||||||
|
torch.randn(vocab_size, embedding_dim)
|
||||||
|
* (embedding_dim ** -0.5)
|
||||||
|
).detach()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
|
|
||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
@ -84,6 +97,9 @@ class Decoder(nn.Module):
|
|||||||
Return a tensor of shape (N, U, embedding_dim).
|
Return a tensor of shape (N, U, embedding_dim).
|
||||||
"""
|
"""
|
||||||
embedding_out = self.embedding(y)
|
embedding_out = self.embedding(y)
|
||||||
|
if self.modify_embedding_layer:
|
||||||
|
embedding_out = self.embedding(y) * (self.embedding_dim ** 0.5)
|
||||||
|
|
||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
if need_pad is True:
|
if need_pad is True:
|
||||||
|
@ -179,6 +179,14 @@ def get_parser():
|
|||||||
"with this parameter before adding to the final loss.",
|
"with this parameter before adding to the final loss.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--modify-embedding",
|
||||||
|
type=bool,
|
||||||
|
default=True,
|
||||||
|
help="When True, we modify the decoder embedding layer."
|
||||||
|
"When False, we don't modify the decoder embedding layer.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
@ -284,6 +292,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
embedding_dim=params.embedding_dim,
|
embedding_dim=params.embedding_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
context_size=params.context_size,
|
context_size=params.context_size,
|
||||||
|
modify_embedding_layer=params.modify_embedding,
|
||||||
)
|
)
|
||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user